1 #include "chainerx/backward_builder.h"
2
3 #include <algorithm>
4 #include <cstddef>
5 #include <initializer_list>
6 #include <iterator>
7 #include <memory>
8 #include <tuple>
9 #include <unordered_set>
10 #include <utility>
11 #include <vector>
12
13 #include <gsl/gsl>
14
15 #include "chainerx/array.h"
16 #include "chainerx/array_node.h"
17 #include "chainerx/backprop_mode.h"
18 #include "chainerx/device.h"
19 #include "chainerx/graph.h"
20 #include "chainerx/macro.h"
21 #include "chainerx/op_node.h"
22
23 namespace chainerx {
24 namespace {
25
26 using internal::ArrayNode;
27 using internal::OpNode;
28
29 } // namespace
30
Target(BackwardBuilder & builder,std::vector<size_t> input_indices)31 BackwardBuilder::Target::Target(BackwardBuilder& builder, std::vector<size_t> input_indices)
32 : builder_{builder}, input_indices_{std::move(input_indices)} {
33 // All input arrays must have the same device.
__anone1c82fc30202(size_t input_index) 34 CHAINERX_ASSERT(std::all_of(input_indices.begin(), input_indices.end(), [this](size_t input_index) {
35 return &gsl::at(builder_.inputs_, input_index).get().device() == &(builder_.inputs_.front().get().device());
36 }));
37
38 graph_to_input_array_nodes_ = CreateInputArrayNodesMap();
39 }
40
// Builds a mapping from each graph (backprop ID) requiring backprop to the input array
// nodes participating in that graph. Each mapped vector is parallel to builder_.inputs_;
// entries for inputs not belonging to the graph are left as nullptr.
absl::flat_hash_map<BackpropId, BackwardBuilder::Target::InputArrayNodes> BackwardBuilder::Target::CreateInputArrayNodesMap() const {
    absl::flat_hash_map<BackpropId, InputArrayNodes> graph_to_input_array_nodes{};

    if (builder_.has_any_applicable_outputs_) {
        // At least one of the output arrays can have gradients.

        for (size_t input_index : input_indices_) {
            // Need to access the input array via the builder.
            const Array& input = gsl::at(builder_.inputs_, input_index);

            for (std::shared_ptr<ArrayNode>& input_array_node : internal::GetArrayBody(input)->nodes()) {
                const BackpropId& backprop_id = input_array_node->backprop_id();
                // Skip graphs on which backprop is currently disabled.
                if (!IsBackpropRequired(backprop_id)) {
                    continue;
                }

                // Add the array node to the mapping
                auto insert_result = graph_to_input_array_nodes.emplace(backprop_id, InputArrayNodes{});
                auto& input_array_nodes = insert_result.first->second;
                if (insert_result.second) {
                    // New array node for a graph. Fill all array nodes with nullptr.
                    input_array_nodes.resize(builder_.inputs_.size());
                }
                // Assign valid pointer to the array node.
                // Note: this stores a pointer to the shared_ptr held inside the array body's
                // node list, so it stays valid only while that body is alive and its node
                // list is not modified.
                input_array_nodes[input_index] = &input_array_node;
            }
        }

        if (CHAINERX_DEBUG) {
            // Sanity check: every non-null entry is keyed under the graph it belongs to.
            for (auto& pair : graph_to_input_array_nodes) {
                const BackpropId& backprop_id = pair.first;
                (void)backprop_id;  // maybe unused
                const InputArrayNodes& input_array_nodes = pair.second;
                for (const std::shared_ptr<ArrayNode>* array_node : input_array_nodes) {
                    CHAINERX_ASSERT(array_node == nullptr || backprop_id == (*array_node)->backprop_id());
                }
            }
        }
    }

    return graph_to_input_array_nodes;
}
83
Define(const BackwardFunction & backward_func)84 void BackwardBuilder::Target::Define(const BackwardFunction& backward_func) {
85 CHAINERX_ASSERT(is_definition_required());
86
87 // Find/Create an op node for each graph and register the given backward function to each of them.
88 for (const auto& pair : graph_to_input_array_nodes_) {
89 const BackpropId& backprop_id = pair.first;
90 const InputArrayNodes& input_array_nodes = pair.second;
91
92 std::vector<std::tuple<size_t, std::shared_ptr<ArrayNode>>> temp_input_array_nodes;
93 temp_input_array_nodes.reserve(input_array_nodes.size());
94 std::transform(
95 input_indices_.begin(),
96 input_indices_.end(),
97 std::back_inserter(temp_input_array_nodes),
98 [&input_array_nodes](size_t input_index) {
99 const std::shared_ptr<ArrayNode>* array_node = input_array_nodes[input_index];
100 return std::make_tuple(input_index, array_node == nullptr ? nullptr : *array_node);
101 });
102
103 std::shared_ptr<OpNode>& op_node = builder_.FindOrCreateOpNode(backprop_id);
104 op_node->RegisterBackwardFunction(std::move(temp_input_array_nodes), backward_func);
105 }
106 }
107
// Constructs a builder for defining the backward pass of an op.
// All input/output arrays must have valid bodies, inputs and outputs must each share a
// single device, and outputs must not yet have any array nodes (no grads attached).
BackwardBuilder::BackwardBuilder(const char* op_name, std::vector<ConstArrayRef> inputs, std::vector<ConstArrayRef> outputs)
    : op_name_{op_name},
      // NOTE(review): member initialization follows declaration order in the header, not
      // this list; this relies on context_ being initialized before inputs_ so that the
      // `inputs` parameter is read before being moved from below — confirm in the header.
      context_{inputs.front().get().context()},
      inputs_{std::move(inputs)},
      // From here on, sizes must come from the members, since the parameters are moved-from.
      inputs_target_created_(inputs_.size()),
      outputs_{std::move(outputs)},
      input_retention_record_{inputs_.size()},
      output_retention_record_{outputs_.size()} {
    CHAINERX_ASSERT(!inputs_.empty());
    CHAINERX_ASSERT(!outputs_.empty());
    CHAINERX_ASSERT(inputs_.size() == inputs_target_created_.size());
    // Every input and output array must have a valid body.
    CHAINERX_ASSERT(
            std::all_of(inputs_.begin(), inputs_.end(), [](const Array& input) { return internal::GetArrayBody(input) != nullptr; }));
    CHAINERX_ASSERT(
            std::all_of(outputs_.begin(), outputs_.end(), [](const Array& output) { return internal::GetArrayBody(output) != nullptr; }));
    // Outputs requiring grad (e.g. in-place ops.) must have been detected and reported before reaching here.
    CHAINERX_ASSERT(std::all_of(
            outputs_.begin(), outputs_.end(), [](const Array& output) { return internal::GetArrayBody(output)->nodes().empty(); }));
    // Arrays must be on the same device within inputs / outputs respectively.
    CHAINERX_ASSERT(std::all_of(outputs_.begin(), outputs_.end(), [this](const Array& output) {
        return &outputs_.begin()->get().device() == &output.device();
    }));
    CHAINERX_ASSERT(std::all_of(
            inputs_.begin(), inputs_.end(), [this](const Array& input) { return &inputs_.begin()->get().device() == &input.device(); }));

    // Only floating-point outputs can carry gradients; if none qualify, no backward graph
    // needs to be built for this op.
    has_any_applicable_outputs_ =
            std::any_of(outputs_.begin(), outputs_.end(), [](const Array& output) { return GetKind(output.dtype()) == DtypeKind::kFloat; });
}
136
FindOrCreateOpNode(const BackpropId & backprop_id)137 std::shared_ptr<OpNode>& BackwardBuilder::FindOrCreateOpNode(const BackpropId& backprop_id) {
138 // Try to find an existing op node for the given graph.
139 auto insert_result = op_node_map_.emplace(backprop_id, nullptr);
140
141 // If not found, create a new one.
142 if (insert_result.second) {
143 insert_result.first->second = OpNode::CreateWithOutputArrayNodes(op_name_, backprop_id, inputs_.size(), outputs_);
144 }
145
146 CHAINERX_ASSERT(!op_node_map_.empty());
147 return insert_result.first->second;
148 }
149
RetainInput(size_t input_index)150 RetainedInputToken BackwardBuilder::RetainInput(size_t input_index) {
151 CHAINERX_ASSERT(input_index < inputs_.size());
152 input_retention_record_.Record(input_index);
153 return {internal::GetArrayBody(gsl::at(inputs_, input_index))->GetParams(), input_index};
154 }
155
RetainInput(std::vector<size_t> indices)156 std::vector<RetainedInputToken> BackwardBuilder::RetainInput(std::vector<size_t> indices) {
157 std::vector<RetainedInputToken> token;
158 for (size_t i : indices) {
159 CHAINERX_ASSERT(i < inputs_.size());
160 input_retention_record_.Record(i);
161 token.emplace_back(internal::GetArrayBody(gsl::at(inputs_, i))->GetParams(), i);
162 }
163 return token;
164 }
165
RetainOutput(size_t output_index)166 RetainedOutputToken BackwardBuilder::RetainOutput(size_t output_index) {
167 CHAINERX_ASSERT(output_index < outputs_.size());
168 output_retention_record_.Record(output_index);
169 return {internal::GetArrayBody(gsl::at(outputs_, output_index))->GetParams(), output_index};
170 }
171
RetainOutput(std::vector<size_t> indices)172 std::vector<RetainedOutputToken> BackwardBuilder::RetainOutput(std::vector<size_t> indices) {
173 std::vector<RetainedOutputToken> token;
174 for (size_t i : indices) {
175 CHAINERX_ASSERT(i < outputs_.size());
176 output_retention_record_.Record(i);
177 token.emplace_back(internal::GetArrayBody(gsl::at(outputs_, i))->GetParams(), i);
178 }
179 return token;
180 }
181
Finalize()182 void BackwardBuilder::Finalize() {
183 CHAINERX_ASSERT(!is_finalized_);
184 // Checks that the backward definitions cover all the input arrays.
185 CHAINERX_ASSERT(std::all_of(inputs_target_created_.begin(), inputs_target_created_.end(), [](bool done) { return done; }));
186
187 AddEdgesFromOpNodeToArrayNodeOfOuterGraphsForRetention();
188
189 // Connect each pair of backprop IDs concerned in this op.
190 // If two backprop IDs are connected, backpropping on the one with lower ordinal will prohibit future backprop on the other.
191 ConnectBackpropIds();
192
193 is_finalized_ = true;
194 }
195
196 namespace {
197
AddEdgesFromOpNodeToInputArrayNodesOfOuterGraph(const OpNode & outer_op_node,OpNode & inner_op_node,const backward_builder_detail::RetentionRecord & input_retention_record)198 void AddEdgesFromOpNodeToInputArrayNodesOfOuterGraph(
199 const OpNode& outer_op_node, OpNode& inner_op_node, const backward_builder_detail::RetentionRecord& input_retention_record) {
200 std::vector<std::shared_ptr<internal::ArrayNode>> input_array_nodes;
201 input_array_nodes.reserve(input_retention_record.size());
202
203 for (size_t i = 0; i < input_retention_record.size(); ++i) {
204 if (input_retention_record.IsRecorded(i)) {
205 input_array_nodes.emplace_back(outer_op_node.input_array_nodes()[i]);
206 } else {
207 input_array_nodes.emplace_back(nullptr);
208 }
209 }
210
211 inner_op_node.AddEdgesToInputArrayNodesOfOuterGraph(outer_op_node.backprop_id(), std::move(input_array_nodes));
212 }
213
AddEdgesFromOpNodeToOutputArrayNodesOfOuterGraph(const OpNode & outer_op_node,OpNode & inner_op_node,const backward_builder_detail::RetentionRecord & output_retention_record)214 void AddEdgesFromOpNodeToOutputArrayNodesOfOuterGraph(
215 const OpNode& outer_op_node, OpNode& inner_op_node, const backward_builder_detail::RetentionRecord& output_retention_record) {
216 std::vector<std::shared_ptr<internal::ArrayNode>> output_array_nodes;
217 output_array_nodes.reserve(output_retention_record.size());
218
219 // Outer graphs must be registered using shared_ptr but op nodes only have weak_ptr to their output array nodes.
220 // Therefore, first convert the weak_ptr to shared_ptr, assuming that they have not expired.
221 for (size_t i = 0; i < output_retention_record.size(); ++i) {
222 if (output_retention_record.IsRecorded(i)) {
223 const absl::optional<std::weak_ptr<ArrayNode>>& array_node = outer_op_node.output_array_nodes()[i];
224 CHAINERX_ASSERT(array_node.has_value());
225 CHAINERX_ASSERT(!array_node->expired());
226 output_array_nodes.emplace_back(array_node->lock());
227 } else {
228 output_array_nodes.emplace_back(nullptr);
229 }
230 }
231
232 inner_op_node.AddEdgesToOutputArrayNodesOfOuterGraph(outer_op_node.backprop_id(), std::move(output_array_nodes));
233 }
234
235 } // namespace
236
// Wires up edges from inner-graph op nodes to outer-graph array nodes so that arrays
// retained via RetainInput/RetainOutput can be restored during backward.
void BackwardBuilder::AddEdgesFromOpNodeToArrayNodeOfOuterGraphsForRetention() {
    // Create edges from op nodes to outer graph array nodes so that retained inputs and output can be restored.
    // For outputs, we need to consider all graphs that this builder defines since each output participates in all of them.
    // For inputs, we only need to consider a subset of the graphs; the graphs that each input belongs to.

    // Add edges to input array nodes
    if (input_retention_record_.IsAnyRecorded()) {
        // Collect graphs to which the retained inputs belong.
        // TODO(beam2d): Use a lighter container.
        std::unordered_set<BackpropId> retained_graphs{};
        for (size_t i = 0; i < input_retention_record_.size(); ++i) {
            if (input_retention_record_.IsRecorded(i)) {
                for (const std::shared_ptr<ArrayNode>& array_node : internal::GetArrayBody(gsl::at(inputs_, i))->nodes()) {
                    retained_graphs.emplace(array_node->backprop_id());
                }
            }
        }

        // Add edges to the input array nodes belonging to the collected graphs.
        // For each ordered pair of distinct graphs (backprop_id < other_backprop_id), the
        // op node of the "inner" (later-ordinal) graph gets edges into the "outer" one.
        for (const BackpropId& backprop_id : retained_graphs) {
            const OpNode& op_node = *op_node_map_.at(backprop_id);
            for (const BackpropId& other_backprop_id : retained_graphs) {
                if (backprop_id < other_backprop_id) {
                    OpNode& other_op_node = *op_node_map_.at(other_backprop_id);
                    AddEdgesFromOpNodeToInputArrayNodesOfOuterGraph(op_node, other_op_node, input_retention_record_);
                }
            }
        }
    }

    // Add edges to output array nodes
    // Unlike inputs, every graph defined by this builder is involved, so iterate over the
    // whole op node map in both dimensions, again pairing outer (lower ID) with inner.
    if (output_retention_record_.IsAnyRecorded()) {
        for (const auto& tup : op_node_map_) {
            const BackpropId& backprop_id = tup.first;
            const OpNode& op_node = *tup.second;
            for (const auto& other_tup : op_node_map_) {
                const BackpropId& other_backprop_id = other_tup.first;
                OpNode& other_op_node = *other_tup.second;
                if (backprop_id < other_backprop_id) {
                    AddEdgesFromOpNodeToOutputArrayNodesOfOuterGraph(op_node, other_op_node, output_retention_record_);
                }
            }
        }
    }
}
282
ConnectBackpropIds()283 void BackwardBuilder::ConnectBackpropIds() {
284 for (auto it1 = op_node_map_.begin(); it1 != op_node_map_.end(); ++it1) {
285 const BackpropId& backprop_id1 = it1->first;
286 for (auto it2 = std::next(it1); it2 != op_node_map_.end(); ++it2) {
287 const BackpropId& backprop_id2 = it2->first;
288 context_.ConnectBackpropIds(backprop_id1, backprop_id2);
289 }
290 }
291 }
292
293 } // namespace chainerx
294