1 /* 2 * SPDX-License-Identifier: Apache-2.0 3 */ 4 5 // Adapter for Scatter in default domain from version 10 to 11 6 7 #pragma once 8 9 namespace ONNX_NAMESPACE { namespace version_conversion { 10 11 class Scatter_10_11 final : public Adapter { 12 public: Scatter_10_11()13 explicit Scatter_10_11() 14 : Adapter("Scatter", OpSetID(10), OpSetID(11)) {} 15 adapt_scatter_10_11(std::shared_ptr<Graph> graph,Node * node)16 Node* adapt_scatter_10_11(std::shared_ptr<Graph> graph, Node* node) const { 17 int axis = node->hasAttribute(kaxis) ? node->i(kaxis) : 0; 18 19 // Replace the node with an equivalent ScatterElements node 20 Node* scatter_elements = graph->create(kScatterElements); 21 scatter_elements->i_(kaxis, axis); 22 scatter_elements->addInput(node->inputs()[0]); 23 scatter_elements->addInput(node->inputs()[1]); 24 scatter_elements->addInput(node->inputs()[2]); 25 node->replaceAllUsesWith(scatter_elements); 26 27 scatter_elements->insertBefore(node); 28 node->destroy(); 29 30 return scatter_elements; 31 } 32 adapt(std::shared_ptr<Graph> graph,Node * node)33 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override { 34 return adapt_scatter_10_11(graph, node); 35 } 36 }; 37 38 }} // namespace ONNX_NAMESPACE::version_conversion 39