1 #pragma once
2 
3 #include <cstddef>
4 #include <cstdint>
5 #include <functional>
6 #include <initializer_list>
7 #include <memory>
8 #include <numeric>
9 #include <set>
10 #include <utility>
11 #include <vector>
12 
13 #include <absl/container/flat_hash_map.h>
14 #include <gsl/gsl>
15 
16 #include "chainerx/array.h"
17 #include "chainerx/array_body.h"
18 #include "chainerx/array_node.h"
19 #include "chainerx/constant.h"
20 #include "chainerx/device.h"
21 #include "chainerx/dtype.h"
22 #include "chainerx/graph.h"
23 #include "chainerx/macro.h"
24 #include "chainerx/op_node.h"
25 #include "chainerx/shape.h"
26 
27 namespace chainerx {
28 namespace backward_builder_detail {
29 
30 // This class is used by the BackwardBuilder to record retained inputs and outputs.
31 // The records are used to create outer graph edges (between op nodes and previous array nodes) when the builder is finalized.
32 class RetentionRecord {
33 public:
RetentionRecord(size_t size)34     explicit RetentionRecord(size_t size) : size_{size} { CHAINERX_ASSERT(size_ > 0); }
35 
size()36     size_t size() const { return size_; }
37 
Record(size_t index)38     void Record(size_t index) {
39         if (flags_.empty()) {
40             flags_.resize(size_);
41         }
42         gsl::at(flags_, index) = static_cast<int8_t>(true);
43     }
44 
IsAnyRecorded()45     bool IsAnyRecorded() const { return !flags_.empty(); }
46 
IsRecorded(size_t index)47     bool IsRecorded(size_t index) const { return static_cast<bool>(flags_[index]); }
48 
49 private:
50     size_t size_{};
51     std::vector<int8_t> flags_{};  // binary flags
52 };
53 
54 template <typename Tag>
55 class RetainedArrayToken {
56 public:
RetainedArrayToken(internal::ArrayBody::Params array_params,size_t index)57     RetainedArrayToken(internal::ArrayBody::Params array_params, size_t index) : array_params_{std::move(array_params)}, index_{index} {}
58 
59     ~RetainedArrayToken() = default;
60 
61     RetainedArrayToken(const RetainedArrayToken&) = default;
62     RetainedArrayToken(RetainedArrayToken&&) noexcept = default;
63     RetainedArrayToken& operator=(const RetainedArrayToken&) = default;
64     // TODO(hvy): Make the move assignment operator noexcept.
65     RetainedArrayToken& operator=(RetainedArrayToken&&) = default;  // NOLINT(performance-noexcept-move-constructor)
66 
67 private:
68     friend class chainerx::BackwardContext;
69 
70     // Returns the array index.
index()71     size_t index() const { return index_; }
72 
array_params()73     const internal::ArrayBody::Params& array_params() const { return array_params_; }
74 
75     internal::ArrayBody::Params array_params_;
76 
77     size_t index_;
78 };
79 
80 }  // namespace backward_builder_detail
81 
// An object used by op implementations to bridge between BackwardBuilder::RetainInput() and BackwardContext::GetRetainedInput().
//
// This is backward_builder_detail::RetainedArrayToken instantiated with the `InputTag` tag type. The tag is only named here (through an
// elaborated type specifier) and is not defined in this header.
//
// See BackwardBuilder::RetainInput() for details.
using RetainedInputToken = backward_builder_detail::RetainedArrayToken<struct InputTag>;
86 
// An object used by op implementations to bridge between BackwardBuilder::RetainOutput() and BackwardContext::GetRetainedOutput().
//
// This is backward_builder_detail::RetainedArrayToken instantiated with the `OutputTag` tag type. The tag is only named here (through an
// elaborated type specifier) and is not defined in this header.
//
// See BackwardBuilder::RetainOutput() for details.
using RetainedOutputToken = backward_builder_detail::RetainedArrayToken<struct OutputTag>;
91 
92 // A class that is used to define backward operations and connect the graph.
93 //
94 // This class is not thread safe.
95 class BackwardBuilder {
96 public:
97     // Target is responsible to define edges from OpNode to input ArrayNodes with given BackwardFunction.
98     // Note that Targets built from the same BackwardBuilder share some properties not to compute again.
99     class Target {
100     public:
101         explicit operator bool() const { return is_definition_required(); }
102 
103         // Defines a backward function with respect to specified input arrays (target).
104         void Define(const BackwardFunction& backward_func);
105 
is_definition_required()106         bool is_definition_required() const { return !graph_to_input_array_nodes_.empty(); }
107 
108     private:
109         friend class BackwardBuilder;  // Only BackwardBuilder can create Target
110 
111         using InputArrayNodes = std::vector<const std::shared_ptr<internal::ArrayNode>*>;
112 
113         Target(BackwardBuilder& builder, std::vector<size_t> input_indices);
114 
115         // Collect input ArrayNodes, grouped by graph considering IsBackpropRequired.
116         // This functions is only called once in the constructor.
117         absl::flat_hash_map<BackpropId, InputArrayNodes> CreateInputArrayNodesMap() const;
118 
119         BackwardBuilder& builder_;
120         std::vector<size_t> input_indices_;
121 
122         // TODO(hvy): Consider using linear search since elements are usually few.
123         absl::flat_hash_map<BackpropId, InputArrayNodes> graph_to_input_array_nodes_;
124     };
125 
126     // TODO(niboshi): Add an overload to accept `const std::vector<Array>&` as `inputs` and `outputs`
127     // Note that simply overloading with the above type will results in ambiguous calls.
128     // One solution is to define a type that accepts all of the expected types of inputs.
129     BackwardBuilder(const char* op_name, std::vector<ConstArrayRef> inputs, std::vector<ConstArrayRef> outputs);
BackwardBuilder(const char * op_name,const Array & input,std::vector<ConstArrayRef> outputs)130     BackwardBuilder(const char* op_name, const Array& input, std::vector<ConstArrayRef> outputs)
131         : BackwardBuilder{op_name, std::vector<ConstArrayRef>{input}, std::move(outputs)} {}
BackwardBuilder(const char * op_name,std::vector<ConstArrayRef> inputs,const Array & output)132     BackwardBuilder(const char* op_name, std::vector<ConstArrayRef> inputs, const Array& output)
133         : BackwardBuilder{op_name, std::move(inputs), std::vector<ConstArrayRef>{output}} {}
BackwardBuilder(const char * op_name,const Array & input,const Array & output)134     BackwardBuilder(const char* op_name, const Array& input, const Array& output)
135         : BackwardBuilder{op_name, std::vector<ConstArrayRef>{input}, std::vector<ConstArrayRef>{output}} {}
~BackwardBuilder()136     ~BackwardBuilder() { CHAINERX_ASSERT(is_finalized_); }
137 
138     BackwardBuilder(const BackwardBuilder&) = delete;
139     BackwardBuilder(BackwardBuilder&&) noexcept = default;
140     BackwardBuilder& operator=(const BackwardBuilder&) = delete;
141     BackwardBuilder& operator=(BackwardBuilder&&) = delete;
142 
143     // Creates a backward target for the specified inputs.
CreateTarget(std::vector<size_t> input_indices)144     Target CreateTarget(std::vector<size_t> input_indices) {
145         // input_indices shouldn't have duplicates.
146         CHAINERX_ASSERT((std::set<size_t>{input_indices.begin(), input_indices.end()}.size() == input_indices.size()));
147 
148         for (size_t input_index : input_indices) {
149             CHAINERX_ASSERT(input_index < inputs_target_created_.size());
150             CHAINERX_ASSERT(!inputs_target_created_[input_index]);
151             inputs_target_created_[input_index] = true;
152         }
153         return Target{*this, std::move(input_indices)};
154     }
155 
156     // Creates a backward target for the specified input.
CreateTarget(size_t input_index)157     Target CreateTarget(size_t input_index) { return CreateTarget(std::vector<size_t>{input_index}); }
158 
159     // Creates a backward target for all the inputs.
CreateTarget()160     Target CreateTarget() {
161         std::vector<size_t> input_indices;
162         input_indices.resize(inputs_.size());
163         std::iota(input_indices.begin(), input_indices.end(), size_t{0});
164 
165         return CreateTarget(std::move(input_indices));
166     }
167 
168     // TODO(hvy): Write comment.
169     RetainedInputToken RetainInput(size_t input_index);
170 
171     std::vector<RetainedInputToken> RetainInput(std::vector<size_t> indices);
172 
173     // Flags an output array to be retained for use in the backward pass.
174     // Op implementations can use this function in combination with BackwardContext::GetRetainedOutput() to retrieve output arrays in the
175     // backward pass.
176     //
177     // If an op implementation requires the output array of the forward pass in the backward pass, it should call
178     // BackwardBuilder::RetainOutput() in the forward pass and keep its return value (either assign a variable or capture by
179     // value in a lambda expression). In the backward pass, it should call BackwardContext::GetRetainedOutput() with this token to retrieve
180     // the output array.
181     //
182     // Capturing the output array directly with lambda expression would cause cyclic reference and therefore would lead to memory leak.
183     //
184     // Reusing the token for higher-order backward functions results in undefined behavior.
185     //
186     // `output` must be one of the arrays specified in the constructor of BackwardBuilder as output arrays.
187     // If invalid array is specified, ChainerxError will be thrown.
188     RetainedOutputToken RetainOutput(size_t output_index);
189     std::vector<RetainedOutputToken> RetainOutput(std::vector<size_t> indices);
190 
191     // Finalizes the builder.
192     //
193     // This functions must be called when targets have been created for all inputs.
194     void Finalize();
195 
196 private:
197     // Create an op node for a specific graph.
198     // Edges from output nodes to the op node are connected.
199     std::shared_ptr<internal::OpNode>& FindOrCreateOpNode(const BackpropId& backprop_id);
200 
201     // Add shared ptrs between op nodes and array nodes belonging to outer graphs.
202     // This functions is called once when the builder is finalized.
203     // These references are required to restore retained inputs/outputs.
204     void AddEdgesFromOpNodeToArrayNodeOfOuterGraphsForRetention();
205 
206     void ConnectBackpropIds();
207 
208     const char* op_name_;
209 
210     Context& context_;
211 
212     // Input arrays of the op.
213     std::vector<ConstArrayRef> inputs_;
214 
215     // Flags indicating whether CreateTarget has been called for each of the input arrays.
216     // All of these flags must be true after all the backwards have been defined for a BackwardBuilder.
217     // This can be checked by calling is_complete();
218     std::vector<bool> inputs_target_created_;
219 
220     // Output arrays of the op.
221     std::vector<ConstArrayRef> outputs_;
222 
223     // A collection of op nodes, each of which corresponds to a graph.
224     // This record is increasingly populated as new graphs are encountered in multiple Define() calls.
225     absl::flat_hash_map<BackpropId, std::shared_ptr<internal::OpNode>> op_node_map_;
226 
227     backward_builder_detail::RetentionRecord input_retention_record_;
228     backward_builder_detail::RetentionRecord output_retention_record_;
229 
230     bool has_any_applicable_outputs_;
231     bool is_finalized_{false};
232 };
233 
234 }  // namespace chainerx
235