/*
 * SPDX-License-Identifier: Apache-2.0
 */

// Adapter for broadcasting ops in default domain from version 6 to 7
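//
// Opset 7 switched the binary broadcasting ops (Add, Sub, Mul, Div, ...) from
// the legacy 'axis'/'broadcast' attributes to numpy-style multidirectional
// broadcasting. This adapter drops those attributes and, where the legacy
// alignment differs from numpy's right-alignment, inserts an Unsqueeze on the
// second input so the converted model keeps the original semantics.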

#pragma once

#include "onnx/version_converter/adapters/adapter.h"

namespace ONNX_NAMESPACE { namespace version_conversion {

class BroadcastForwardCompatibility final : public Adapter {
  public:
    explicit BroadcastForwardCompatibility(const std::string& op_name, const OpSetID& initial, const OpSetID& target)
        : Adapter(op_name, initial, target) {}

    void adapt_broadcast_forward_compatibility(std::shared_ptr<Graph> graph, Node* node) const {
      // Remove the legacy 'axis' and 'broadcast' attributes, assessing first
      // whether 'axis' requires reshaping input B to preserve the semantics.
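      //
      // Worked example (illustrative shapes): with A of shape [2, 3, 4, 5],
      // B of shape [2, 3] and axis = 0, opset-6 semantics align B against
      // dims 0..1 of A. Numpy-style broadcasting aligns from the right, so B
      // is first unsqueezed with axes [2, 3] to shape [2, 3, 1, 1] before the
      // attributes are removed.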
      if (node->hasAttribute(kbroadcast)) {
        const ArrayRef<Value*>& inputs = node->inputs();
        assertInputsAvailable(inputs, name().c_str(), 2);
        const std::vector<Dimension>& A_sizes = inputs[0]->sizes();
        const std::vector<Dimension>& B_sizes = inputs[1]->sizes();
        // Also assert that the broadcasting syntax is correct if axis is not present
        if (node->hasAttribute(kaxis)) {
          if (node->i(kaxis) != (int)(A_sizes.size() - B_sizes.size())) {
            // Add an Unsqueeze node that reshapes input B
            Node* n = graph->create(kUnsqueeze);
            n->addInput(inputs[1]);
            // Append one trailing singleton dimension to B for every rank it
            // is short of A, so numpy-style broadcasting lines the dims up.
            std::vector<int64_t> axes;
            std::vector<Dimension> new_sizes = B_sizes;
            auto size = A_sizes.size() > B_sizes.size() ? A_sizes.size() - B_sizes.size() : 0;
            axes.reserve(size);
            new_sizes.reserve(new_sizes.size() + size);
            for (size_t i = 0; i < size; i++) {
              axes.emplace_back(B_sizes.size() + i);
              new_sizes.emplace_back(Dimension(1));
            }
            if (target_version().version() >= 13) { // Unsqueeze takes 'axes' as an input
              // Build a Constant holding the axes and feed it to the
              // Unsqueeze node as its second input.
              Tensor t;
              t.elem_type() = TensorProto_DataType_INT64;
              t.sizes() = std::vector<int64_t>{static_cast<int64_t>(axes.size())};
              auto& data = t.int64s();
              for (auto a : axes) {
                data.emplace_back(a);
              }
              Node* constant = graph->create(kConstant);
              constant->insertBefore(node);
              constant->t_(kvalue, t);
              n->addInput(constant->output());
            } else { // Unsqueeze takes 'axes' as an attribute
              n->is_(kaxes, std::move(axes));
            }
            // Move n before node
            n->insertBefore(node);
            // Rewire node so that its second input is the Unsqueeze output
            n->output()->setSizes(new_sizes);
            node->replaceInput(1, n->output());
          }
        }
        node->removeAttribute(kbroadcast);
      }
      if (node->hasAttribute(kaxis)) node->removeAttribute(kaxis);
      // Assert multi_broadcastable on inputs
      const ArrayRef<Value*>& inputs = node->inputs();
      assert_numpy_multibroadcastable(inputs[0]->sizes(), inputs[1]->sizes());
    }

    Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
      adapt_broadcast_forward_compatibility(graph, node);
      return node;
    }
};
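
// Usage sketch (illustrative, not part of this header): the default version
// converter is expected to register one instance of this adapter per legacy
// broadcasting op when stepping the default domain from opset 6 to 7, along
// the lines of
//
//   registerAdapter(std::make_unique<BroadcastForwardCompatibility>(
//       "Add", OpSetID(6), OpSetID(7)));
//
// so that converting a model to opset 7 calls adapt() on every matching node.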

}} // namespace ONNX_NAMESPACE::version_conversion