1 #pragma once 2 3 #include <cstddef> 4 #include <cstdint> 5 #include <functional> 6 #include <initializer_list> 7 #include <memory> 8 #include <numeric> 9 #include <set> 10 #include <utility> 11 #include <vector> 12 13 #include <absl/container/flat_hash_map.h> 14 #include <gsl/gsl> 15 16 #include "chainerx/array.h" 17 #include "chainerx/array_body.h" 18 #include "chainerx/array_node.h" 19 #include "chainerx/constant.h" 20 #include "chainerx/device.h" 21 #include "chainerx/dtype.h" 22 #include "chainerx/graph.h" 23 #include "chainerx/macro.h" 24 #include "chainerx/op_node.h" 25 #include "chainerx/shape.h" 26 27 namespace chainerx { 28 namespace backward_builder_detail { 29 30 // This class is used by the BackwardBuilder to record retained inputs and outputs. 31 // The records are used to create outer graph edges (between op nodes and previous array nodes) when the builder is finalized. 32 class RetentionRecord { 33 public: RetentionRecord(size_t size)34 explicit RetentionRecord(size_t size) : size_{size} { CHAINERX_ASSERT(size_ > 0); } 35 size()36 size_t size() const { return size_; } 37 Record(size_t index)38 void Record(size_t index) { 39 if (flags_.empty()) { 40 flags_.resize(size_); 41 } 42 gsl::at(flags_, index) = static_cast<int8_t>(true); 43 } 44 IsAnyRecorded()45 bool IsAnyRecorded() const { return !flags_.empty(); } 46 IsRecorded(size_t index)47 bool IsRecorded(size_t index) const { return static_cast<bool>(flags_[index]); } 48 49 private: 50 size_t size_{}; 51 std::vector<int8_t> flags_{}; // binary flags 52 }; 53 54 template <typename Tag> 55 class RetainedArrayToken { 56 public: RetainedArrayToken(internal::ArrayBody::Params array_params,size_t index)57 RetainedArrayToken(internal::ArrayBody::Params array_params, size_t index) : array_params_{std::move(array_params)}, index_{index} {} 58 59 ~RetainedArrayToken() = default; 60 61 RetainedArrayToken(const RetainedArrayToken&) = default; 62 RetainedArrayToken(RetainedArrayToken&&) noexcept = default; 63 RetainedArrayToken& operator=(const RetainedArrayToken&) = default; 64 // TODO(hvy): Make the move assignment operator noexcept. 65 RetainedArrayToken& operator=(RetainedArrayToken&&) = default; // NOLINT(performance-noexcept-move-constructor) 66 67 private: 68 friend class chainerx::BackwardContext; 69 70 // Returns the array index. index()71 size_t index() const { return index_; } 72 array_params()73 const internal::ArrayBody::Params& array_params() const { return array_params_; } 74 75 internal::ArrayBody::Params array_params_; 76 77 size_t index_; 78 }; 79 80 } // namespace backward_builder_detail 81 82 // An object used by op implementations to bridge between BackwardBuilder::RetainInput() and BackwardContext::GetRetainedInput(). 83 // 84 // See BackwardBuilder::RetainInput() for details. 85 using RetainedInputToken = backward_builder_detail::RetainedArrayToken<struct InputTag>; 86 87 // An object used by op implementations to bridge between BackwardBuilder::RetainOutput() and BackwardContext::GetRetainedOutput(). 88 // 89 // See BackwardBuilder::RetainOutput() for details. 90 using RetainedOutputToken = backward_builder_detail::RetainedArrayToken<struct OutputTag>; 91 92 // A class that is used to define backward operations and connect the graph. 93 // 94 // This class is not thread safe. 95 class BackwardBuilder { 96 public: 97 // Target is responsible to define edges from OpNode to input ArrayNodes with given BackwardFunction. 98 // Note that Targets built from the same BackwardBuilder share some properties not to compute again. 99 class Target { 100 public: 101 explicit operator bool() const { return is_definition_required(); } 102 103 // Defines a backward function with respect to specified input arrays (target). 104 void Define(const BackwardFunction& backward_func); 105 is_definition_required()106 bool is_definition_required() const { return !graph_to_input_array_nodes_.empty(); } 107 108 private: 109 friend class BackwardBuilder; // Only BackwardBuilder can create Target 110 111 using InputArrayNodes = std::vector<const std::shared_ptr<internal::ArrayNode>*>; 112 113 Target(BackwardBuilder& builder, std::vector<size_t> input_indices); 114 115 // Collect input ArrayNodes, grouped by graph considering IsBackpropRequired. 116 // This functions is only called once in the constructor. 117 absl::flat_hash_map<BackpropId, InputArrayNodes> CreateInputArrayNodesMap() const; 118 119 BackwardBuilder& builder_; 120 std::vector<size_t> input_indices_; 121 122 // TODO(hvy): Consider using linear search since elements are usually few. 123 absl::flat_hash_map<BackpropId, InputArrayNodes> graph_to_input_array_nodes_; 124 }; 125 126 // TODO(niboshi): Add an overload to accept `const std::vector<Array>&` as `inputs` and `outputs` 127 // Note that simply overloading with the above type will results in ambiguous calls. 128 // One solution is to define a type that accepts all of the expected types of inputs. 129 BackwardBuilder(const char* op_name, std::vector<ConstArrayRef> inputs, std::vector<ConstArrayRef> outputs); BackwardBuilder(const char * op_name,const Array & input,std::vector<ConstArrayRef> outputs)130 BackwardBuilder(const char* op_name, const Array& input, std::vector<ConstArrayRef> outputs) 131 : BackwardBuilder{op_name, std::vector<ConstArrayRef>{input}, std::move(outputs)} {} BackwardBuilder(const char * op_name,std::vector<ConstArrayRef> inputs,const Array & output)132 BackwardBuilder(const char* op_name, std::vector<ConstArrayRef> inputs, const Array& output) 133 : BackwardBuilder{op_name, std::move(inputs), std::vector<ConstArrayRef>{output}} {} BackwardBuilder(const char * op_name,const Array & input,const Array & output)134 BackwardBuilder(const char* op_name, const Array& input, const Array& output) 135 : BackwardBuilder{op_name, std::vector<ConstArrayRef>{input}, std::vector<ConstArrayRef>{output}} {} ~BackwardBuilder()136 ~BackwardBuilder() { CHAINERX_ASSERT(is_finalized_); } 137 138 BackwardBuilder(const BackwardBuilder&) = delete; 139 BackwardBuilder(BackwardBuilder&&) noexcept = default; 140 BackwardBuilder& operator=(const BackwardBuilder&) = delete; 141 BackwardBuilder& operator=(BackwardBuilder&&) = delete; 142 143 // Creates a backward target for the specified inputs. CreateTarget(std::vector<size_t> input_indices)144 Target CreateTarget(std::vector<size_t> input_indices) { 145 // input_indices shouldn't have duplicates. 146 CHAINERX_ASSERT((std::set<size_t>{input_indices.begin(), input_indices.end()}.size() == input_indices.size())); 147 148 for (size_t input_index : input_indices) { 149 CHAINERX_ASSERT(input_index < inputs_target_created_.size()); 150 CHAINERX_ASSERT(!inputs_target_created_[input_index]); 151 inputs_target_created_[input_index] = true; 152 } 153 return Target{*this, std::move(input_indices)}; 154 } 155 156 // Creates a backward target for the specified input. CreateTarget(size_t input_index)157 Target CreateTarget(size_t input_index) { return CreateTarget(std::vector<size_t>{input_index}); } 158 159 // Creates a backward target for all the inputs. CreateTarget()160 Target CreateTarget() { 161 std::vector<size_t> input_indices; 162 input_indices.resize(inputs_.size()); 163 std::iota(input_indices.begin(), input_indices.end(), size_t{0}); 164 165 return CreateTarget(std::move(input_indices)); 166 } 167 168 // TODO(hvy): Write comment. 169 RetainedInputToken RetainInput(size_t input_index); 170 171 std::vector<RetainedInputToken> RetainInput(std::vector<size_t> indices); 172 173 // Flags an output array to be retained for use in the backward pass. 174 // Op implementations can use this function in combination with BackwardContext::GetRetainedOutput() to retrieve output arrays in the 175 // backward pass. 176 // 177 // If an op implementation requires the output array of the forward pass in the backward pass, it should call 178 // BackwardBuilder::RetainOutput() in the forward pass and keep its return value (either assign a variable or capture by 179 // value in a lambda expression). In the backward pass, it should call BackwardContext::GetRetainedOutput() with this token to retrieve 180 // the output array. 181 // 182 // Capturing the output array directly with lambda expression would cause cyclic reference and therefore would lead to memory leak. 183 // 184 // Reusing the token for higher-order backward functions results in undefined behavior. 185 // 186 // `output` must be one of the arrays specified in the constructor of BackwardBuilder as output arrays. 187 // If invalid array is specified, ChainerxError will be thrown. 188 RetainedOutputToken RetainOutput(size_t output_index); 189 std::vector<RetainedOutputToken> RetainOutput(std::vector<size_t> indices); 190 191 // Finalizes the builder. 192 // 193 // This functions must be called when targets have been created for all inputs. 194 void Finalize(); 195 196 private: 197 // Create an op node for a specific graph. 198 // Edges from output nodes to the op node are connected. 199 std::shared_ptr<internal::OpNode>& FindOrCreateOpNode(const BackpropId& backprop_id); 200 201 // Add shared ptrs between op nodes and array nodes belonging to outer graphs. 202 // This functions is called once when the builder is finalized. 203 // These references are required to restore retained inputs/outputs. 204 void AddEdgesFromOpNodeToArrayNodeOfOuterGraphsForRetention(); 205 206 void ConnectBackpropIds(); 207 208 const char* op_name_; 209 210 Context& context_; 211 212 // Input arrays of the op. 213 std::vector<ConstArrayRef> inputs_; 214 215 // Flags indicating whether CreateTarget has been called for each of the input arrays. 216 // All of these flags must be true after all the backwards have been defined for a BackwardBuilder. 217 // This can be checked by calling is_complete(); 218 std::vector<bool> inputs_target_created_; 219 220 // Output arrays of the op. 221 std::vector<ConstArrayRef> outputs_; 222 223 // A collection of op nodes, each of which corresponds to a graph. 224 // This record is increasingly populated as new graphs are encountered in multiple Define() calls. 225 absl::flat_hash_map<BackpropId, std::shared_ptr<internal::OpNode>> op_node_map_; 226 227 backward_builder_detail::RetentionRecord input_retention_record_; 228 backward_builder_detail::RetentionRecord output_retention_record_; 229 230 bool has_any_applicable_outputs_; 231 bool is_finalized_{false}; 232 }; 233 234 } // namespace chainerx 235