1 /*
2  * SPDX-License-Identifier: Apache-2.0
3  */
4 
5 // Adapter for Scatter in default domain from version 10 to 11
6 
#pragma once

#include <cstdint>
#include <memory>
8 
9 namespace ONNX_NAMESPACE { namespace version_conversion {
10 
11 class Scatter_10_11 final : public Adapter {
12   public:
Scatter_10_11()13     explicit Scatter_10_11()
14       : Adapter("Scatter", OpSetID(10), OpSetID(11)) {}
15 
adapt_scatter_10_11(std::shared_ptr<Graph> graph,Node * node)16     Node* adapt_scatter_10_11(std::shared_ptr<Graph> graph, Node* node) const {
17       int axis = node->hasAttribute(kaxis) ? node->i(kaxis) : 0;
18 
19       // Replace the node with an equivalent ScatterElements node
20       Node* scatter_elements = graph->create(kScatterElements);
21       scatter_elements->i_(kaxis, axis);
22       scatter_elements->addInput(node->inputs()[0]);
23       scatter_elements->addInput(node->inputs()[1]);
24       scatter_elements->addInput(node->inputs()[2]);
25       node->replaceAllUsesWith(scatter_elements);
26 
27       scatter_elements->insertBefore(node);
28       node->destroy();
29 
30       return scatter_elements;
31     }
32 
adapt(std::shared_ptr<Graph> graph,Node * node)33     Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
34       return adapt_scatter_10_11(graph, node);
35     }
36 };
37 
38 }} // namespace ONNX_NAMESPACE::version_conversion
39