1 #include "chainerx/backward_builder.h"
2 
3 #include <algorithm>
4 #include <cstddef>
5 #include <initializer_list>
6 #include <iterator>
7 #include <memory>
8 #include <tuple>
9 #include <unordered_set>
10 #include <utility>
11 #include <vector>
12 
13 #include <gsl/gsl>
14 
15 #include "chainerx/array.h"
16 #include "chainerx/array_node.h"
17 #include "chainerx/backprop_mode.h"
18 #include "chainerx/device.h"
19 #include "chainerx/graph.h"
20 #include "chainerx/macro.h"
21 #include "chainerx/op_node.h"
22 
23 namespace chainerx {
24 namespace {
25 
26 using internal::ArrayNode;
27 using internal::OpNode;
28 
29 }  // namespace
30 
BackwardBuilder::Target::Target(BackwardBuilder& builder, std::vector<size_t> input_indices)
    : builder_{builder}, input_indices_{std::move(input_indices)} {
    // All input arrays must have the same device.
    CHAINERX_ASSERT(std::all_of(input_indices_.begin(), input_indices_.end(), [this](size_t input_index) {
        return &gsl::at(builder_.inputs_, input_index).get().device() == &(builder_.inputs_.front().get().device());
    }));

    graph_to_input_array_nodes_ = CreateInputArrayNodesMap();
}

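// Builds a mapping from each graph (backprop ID) requiring backprop to the array nodes of the target
// inputs that participate in that graph. Each mapped vector has one entry per input of the builder;
// entries for inputs not belonging to the graph remain nullptr.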
absl::flat_hash_map<BackpropId, BackwardBuilder::Target::InputArrayNodes> BackwardBuilder::Target::CreateInputArrayNodesMap() const {
    absl::flat_hash_map<BackpropId, InputArrayNodes> graph_to_input_array_nodes{};

    if (builder_.has_any_applicable_outputs_) {
        // At least one output array can have gradients.

        for (size_t input_index : input_indices_) {
            // Need to access the input array via the builder.
            const Array& input = gsl::at(builder_.inputs_, input_index);

            for (std::shared_ptr<ArrayNode>& input_array_node : internal::GetArrayBody(input)->nodes()) {
                const BackpropId& backprop_id = input_array_node->backprop_id();
                if (!IsBackpropRequired(backprop_id)) {
                    continue;
                }

                // Add the array node to the mapping.
                auto insert_result = graph_to_input_array_nodes.emplace(backprop_id, InputArrayNodes{});
                auto& input_array_nodes = insert_result.first->second;
                if (insert_result.second) {
                    // New entry for this graph. Initialize with one nullptr entry per input.
                    input_array_nodes.resize(builder_.inputs_.size());
                }
                // Assign a valid pointer to the array node.
                input_array_nodes[input_index] = &input_array_node;
            }
        }

        if (CHAINERX_DEBUG) {
            for (auto& pair : graph_to_input_array_nodes) {
                const BackpropId& backprop_id = pair.first;
                (void)backprop_id;  // maybe unused
                const InputArrayNodes& input_array_nodes = pair.second;
                for (const std::shared_ptr<ArrayNode>* array_node : input_array_nodes) {
                    CHAINERX_ASSERT(array_node == nullptr || backprop_id == (*array_node)->backprop_id());
                }
            }
        }
    }

    return graph_to_input_array_nodes;
}

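// Defines a backward function for this target: an op node is found or created for every concerned
// graph, and the backward function is registered on it together with the (input index, array node)
// pairs it computes gradients for.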
void BackwardBuilder::Target::Define(const BackwardFunction& backward_func) {
    CHAINERX_ASSERT(is_definition_required());

    // Find/Create an op node for each graph and register the given backward function to each of them.
    for (const auto& pair : graph_to_input_array_nodes_) {
        const BackpropId& backprop_id = pair.first;
        const InputArrayNodes& input_array_nodes = pair.second;

        std::vector<std::tuple<size_t, std::shared_ptr<ArrayNode>>> temp_input_array_nodes;
        temp_input_array_nodes.reserve(input_array_nodes.size());
        std::transform(
                input_indices_.begin(),
                input_indices_.end(),
                std::back_inserter(temp_input_array_nodes),
                [&input_array_nodes](size_t input_index) {
                    const std::shared_ptr<ArrayNode>* array_node = input_array_nodes[input_index];
                    return std::make_tuple(input_index, array_node == nullptr ? nullptr : *array_node);
                });

        std::shared_ptr<OpNode>& op_node = builder_.FindOrCreateOpNode(backprop_id);
        op_node->RegisterBackwardFunction(std::move(temp_input_array_nodes), backward_func);
    }
}

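// Typical usage from an op implementation (a rough sketch; the exact interface, including
// CreateTarget(), is declared in backward_builder.h):
//
//     BackwardBuilder bb{"my_op", x, out};
//     if (BackwardBuilder::Target bt = bb.CreateTarget(0)) {
//         bt.Define([](BackwardContext& bctx) { /* set the input gradient via bctx */ });
//     }
//     bb.Finalize();
//
// Finalize() asserts that the Define() calls cover all inputs and performs the cross-graph
// bookkeeping implemented below.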
BackwardBuilder::BackwardBuilder(const char* op_name, std::vector<ConstArrayRef> inputs, std::vector<ConstArrayRef> outputs)
    : op_name_{op_name},
      context_{inputs.front().get().context()},
      inputs_{std::move(inputs)},
      inputs_target_created_(inputs_.size()),
      outputs_{std::move(outputs)},
      input_retention_record_{inputs_.size()},
      output_retention_record_{outputs_.size()} {
    CHAINERX_ASSERT(!inputs_.empty());
    CHAINERX_ASSERT(!outputs_.empty());
    CHAINERX_ASSERT(inputs_.size() == inputs_target_created_.size());
    CHAINERX_ASSERT(
            std::all_of(inputs_.begin(), inputs_.end(), [](const Array& input) { return internal::GetArrayBody(input) != nullptr; }));
    CHAINERX_ASSERT(
            std::all_of(outputs_.begin(), outputs_.end(), [](const Array& output) { return internal::GetArrayBody(output) != nullptr; }));
    // Outputs requiring grad (e.g. in-place ops) must have been detected and reported before reaching here.
    CHAINERX_ASSERT(std::all_of(
            outputs_.begin(), outputs_.end(), [](const Array& output) { return internal::GetArrayBody(output)->nodes().empty(); }));
    // Arrays must be on the same device within inputs / outputs, respectively.
    CHAINERX_ASSERT(std::all_of(outputs_.begin(), outputs_.end(), [this](const Array& output) {
        return &outputs_.begin()->get().device() == &output.device();
    }));
    CHAINERX_ASSERT(std::all_of(
            inputs_.begin(), inputs_.end(), [this](const Array& input) { return &inputs_.begin()->get().device() == &input.device(); }));

    // Only outputs with a floating-point dtype can have gradients.
    has_any_applicable_outputs_ =
            std::any_of(outputs_.begin(), outputs_.end(), [](const Array& output) { return GetKind(output.dtype()) == DtypeKind::kFloat; });
}

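// Returns the op node of the given graph, creating it (together with its output array nodes) on first
// use. op_node_map_ holds exactly one op node per graph concerned in this op.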
std::shared_ptr<OpNode>& BackwardBuilder::FindOrCreateOpNode(const BackpropId& backprop_id) {
    // Try to find an existing op node for the given graph.
    auto insert_result = op_node_map_.emplace(backprop_id, nullptr);

    // If not found, create a new one.
    if (insert_result.second) {
        insert_result.first->second = OpNode::CreateWithOutputArrayNodes(op_name_, backprop_id, inputs_.size(), outputs_);
    }

    CHAINERX_ASSERT(!op_node_map_.empty());
    return insert_result.first->second;
}

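// Retention: RetainInput()/RetainOutput() record which arrays a backward function will need again and
// return tokens with which the arrays can be retrieved during backward. The retention records are
// consumed in Finalize() to add the outer-graph edges needed to restore the retained arrays.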
RetainedInputToken BackwardBuilder::RetainInput(size_t input_index) {
    CHAINERX_ASSERT(input_index < inputs_.size());
    input_retention_record_.Record(input_index);
    return {internal::GetArrayBody(gsl::at(inputs_, input_index))->GetParams(), input_index};
}

std::vector<RetainedInputToken> BackwardBuilder::RetainInput(std::vector<size_t> indices) {
    std::vector<RetainedInputToken> tokens;
    tokens.reserve(indices.size());
    for (size_t i : indices) {
        CHAINERX_ASSERT(i < inputs_.size());
        input_retention_record_.Record(i);
        tokens.emplace_back(internal::GetArrayBody(gsl::at(inputs_, i))->GetParams(), i);
    }
    return tokens;
}

RetainedOutputToken BackwardBuilder::RetainOutput(size_t output_index) {
    CHAINERX_ASSERT(output_index < outputs_.size());
    output_retention_record_.Record(output_index);
    return {internal::GetArrayBody(gsl::at(outputs_, output_index))->GetParams(), output_index};
}

std::vector<RetainedOutputToken> BackwardBuilder::RetainOutput(std::vector<size_t> indices) {
    std::vector<RetainedOutputToken> tokens;
    tokens.reserve(indices.size());
    for (size_t i : indices) {
        CHAINERX_ASSERT(i < outputs_.size());
        output_retention_record_.Record(i);
        tokens.emplace_back(internal::GetArrayBody(gsl::at(outputs_, i))->GetParams(), i);
    }
    return tokens;
}

void BackwardBuilder::Finalize() {
    CHAINERX_ASSERT(!is_finalized_);
    // Check that the backward definitions cover all the input arrays.
    CHAINERX_ASSERT(std::all_of(inputs_target_created_.begin(), inputs_target_created_.end(), [](bool done) { return done; }));

    AddEdgesFromOpNodeToArrayNodeOfOuterGraphsForRetention();

    // Connect each pair of backprop IDs concerned in this op.
    // If two backprop IDs are connected, backpropping on the one with the lower ordinal will prohibit future backprop on the other.
    ConnectBackpropIds();

    is_finalized_ = true;
}

namespace {

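// Helpers that copy an outer graph's retained array nodes from its op node onto an inner graph's op
// node. Entries that were not retained are filled with nullptr so that indices keep lining up with the
// builder's inputs and outputs.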
void AddEdgesFromOpNodeToInputArrayNodesOfOuterGraph(
        const OpNode& outer_op_node, OpNode& inner_op_node, const backward_builder_detail::RetentionRecord& input_retention_record) {
    std::vector<std::shared_ptr<internal::ArrayNode>> input_array_nodes;
    input_array_nodes.reserve(input_retention_record.size());

    for (size_t i = 0; i < input_retention_record.size(); ++i) {
        if (input_retention_record.IsRecorded(i)) {
            input_array_nodes.emplace_back(outer_op_node.input_array_nodes()[i]);
        } else {
            input_array_nodes.emplace_back(nullptr);
        }
    }

    inner_op_node.AddEdgesToInputArrayNodesOfOuterGraph(outer_op_node.backprop_id(), std::move(input_array_nodes));
}

void AddEdgesFromOpNodeToOutputArrayNodesOfOuterGraph(
        const OpNode& outer_op_node, OpNode& inner_op_node, const backward_builder_detail::RetentionRecord& output_retention_record) {
    std::vector<std::shared_ptr<internal::ArrayNode>> output_array_nodes;
    output_array_nodes.reserve(output_retention_record.size());

    // Outer graphs must be registered using shared_ptr but op nodes only have weak_ptr to their output array nodes.
    // Therefore, first convert the weak_ptr to shared_ptr, assuming that they have not expired.
    for (size_t i = 0; i < output_retention_record.size(); ++i) {
        if (output_retention_record.IsRecorded(i)) {
            const absl::optional<std::weak_ptr<ArrayNode>>& array_node = outer_op_node.output_array_nodes()[i];
            CHAINERX_ASSERT(array_node.has_value());
            CHAINERX_ASSERT(!array_node->expired());
            output_array_nodes.emplace_back(array_node->lock());
        } else {
            output_array_nodes.emplace_back(nullptr);
        }
    }

    inner_op_node.AddEdgesToOutputArrayNodesOfOuterGraph(outer_op_node.backprop_id(), std::move(output_array_nodes));
}

}  // namespace

void BackwardBuilder::AddEdgesFromOpNodeToArrayNodeOfOuterGraphsForRetention() {
    // Create edges from op nodes to outer-graph array nodes so that retained inputs and outputs can be restored.
    // For outputs, we need to consider all graphs that this builder defines, since each output participates in all of them.
    // For inputs, we only need to consider a subset of the graphs: the graphs that each input belongs to.
    // In both cases, the graph with the smaller backprop ID acts as the outer graph and the one with the larger ID as the inner graph.

    // Add edges to input array nodes.
    if (input_retention_record_.IsAnyRecorded()) {
        // Collect graphs to which the retained inputs belong.
        // TODO(beam2d): Use a lighter container.
        std::unordered_set<BackpropId> retained_graphs{};
        for (size_t i = 0; i < input_retention_record_.size(); ++i) {
            if (input_retention_record_.IsRecorded(i)) {
                for (const std::shared_ptr<ArrayNode>& array_node : internal::GetArrayBody(gsl::at(inputs_, i))->nodes()) {
                    retained_graphs.emplace(array_node->backprop_id());
                }
            }
        }

        // Add edges to the input array nodes belonging to the collected graphs.
        for (const BackpropId& backprop_id : retained_graphs) {
            const OpNode& op_node = *op_node_map_.at(backprop_id);
            for (const BackpropId& other_backprop_id : retained_graphs) {
                if (backprop_id < other_backprop_id) {
                    OpNode& other_op_node = *op_node_map_.at(other_backprop_id);
                    AddEdgesFromOpNodeToInputArrayNodesOfOuterGraph(op_node, other_op_node, input_retention_record_);
                }
            }
        }
    }

    // Add edges to output array nodes.
    if (output_retention_record_.IsAnyRecorded()) {
        for (const auto& tup : op_node_map_) {
            const BackpropId& backprop_id = tup.first;
            const OpNode& op_node = *tup.second;
            for (const auto& other_tup : op_node_map_) {
                const BackpropId& other_backprop_id = other_tup.first;
                OpNode& other_op_node = *other_tup.second;
                if (backprop_id < other_backprop_id) {
                    AddEdgesFromOpNodeToOutputArrayNodesOfOuterGraph(op_node, other_op_node, output_retention_record_);
                }
            }
        }
    }
}

void BackwardBuilder::ConnectBackpropIds() {
    for (auto it1 = op_node_map_.begin(); it1 != op_node_map_.end(); ++it1) {
        const BackpropId& backprop_id1 = it1->first;
        for (auto it2 = std::next(it1); it2 != op_node_map_.end(); ++it2) {
            const BackpropId& backprop_id2 = it2->first;
            context_.ConnectBackpropIds(backprop_id1, backprop_id2);
        }
    }
}

}  // namespace chainerx