1 /* 2 * SPDX-License-Identifier: Apache-2.0 3 */ 4 5 // Adapter for broadcasting ops in default domain from version 6 to 7 6 7 #pragma once 8 9 #include "onnx/version_converter/adapters/adapter.h" 10 11 namespace ONNX_NAMESPACE { namespace version_conversion { 12 13 class BroadcastForwardCompatibility final : public Adapter { 14 public: BroadcastForwardCompatibility(const std::string & op_name,const OpSetID & initial,const OpSetID & target)15 explicit BroadcastForwardCompatibility(const std::string& op_name, const OpSetID& 16 initial, const OpSetID& target): Adapter(op_name, initial, target) {} 17 adapt_broadcast_forward_compatibility(std::shared_ptr<Graph> graph,Node * node)18 void adapt_broadcast_forward_compatibility(std::shared_ptr<Graph> graph, Node* node) 19 const { 20 // Remove axis and broadcast attributes 21 // Assess whether axis requires reshaping 22 if (node->hasAttribute(kbroadcast)) { 23 const ArrayRef<Value*>& inputs = node->inputs(); 24 assertInputsAvailable(inputs, name().c_str(), 2); 25 const std::vector<Dimension>& A_sizes = inputs[0]->sizes(); 26 const std::vector<Dimension>& B_sizes = inputs[1]->sizes(); 27 // Also assert that broadcasting syntax are correct if axis is not present 28 if (node->hasAttribute(kaxis)) { 29 if (node->i(kaxis) != (int) (A_sizes.size() - B_sizes.size())) { 30 // Add a Reshape node before input B 31 Node * n = graph->create(kUnsqueeze); 32 n->addInput(inputs[1]); 33 std::vector<int64_t> axes; 34 std::vector<Dimension> new_sizes = B_sizes; 35 auto size = A_sizes.size() > B_sizes.size() ? A_sizes.size() - B_sizes.size() : 0; 36 axes.reserve(size); 37 new_sizes.reserve(new_sizes.size() + size); 38 for (size_t i = 0; i < size; i++) { 39 axes.emplace_back(B_sizes.size() + i); 40 new_sizes.emplace_back(Dimension(1)); 41 } 42 if (target_version().version() >= 13){ //Unsqueeze takes 'axes' input 43 Tensor t; 44 t.elem_type() = TensorProto_DataType_INT64; 45 t.sizes() = std::vector<int64_t>{static_cast<int64_t>(axes.size())}; 46 auto& data = t.int64s(); 47 for (auto a : axes) { 48 data.emplace_back(a); 49 } 50 Node* constant = graph->create(kConstant); 51 constant->insertBefore(node); 52 constant->t_(kvalue, t); 53 node->addInput(constant->output()); 54 } else { // Unsqueeze takes 'axes' attribute 55 n->is_(kaxes, std::forward<const std::vector<int64_t>>(axes)); 56 } 57 // Move n before node 58 n->insertBefore(node); 59 // Set 2nd input to node to 1st of n and output of n to 2nd input to node 60 n->output()->setSizes(new_sizes); 61 node->replaceInput(1, n->output()); 62 } 63 } 64 node->removeAttribute(kbroadcast); 65 } 66 if (node->hasAttribute(kaxis)) node->removeAttribute(kaxis); 67 // Assert multi_broadcastable on inputs 68 const ArrayRef<Value*>& inputs = node->inputs(); 69 assert_numpy_multibroadcastable(inputs[0]->sizes(), inputs[1]->sizes()); 70 } 71 adapt(std::shared_ptr<Graph> graph,Node * node)72 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override { 73 adapt_broadcast_forward_compatibility(graph, node); 74 return node; 75 } 76 }; 77 78 }} // namespace ONNX_NAMESPACE::version_conversion 79