1 /*
2  * SPDX-License-Identifier: Apache-2.0
3  */
4 
5 #include "onnx/version_converter/convert.h"
6 
7 namespace ONNX_NAMESPACE { namespace version_conversion {
8 
// Converts the default-domain ("" or "ai.onnx") operator set of mp_in to
// target_version using a DefaultVersionConverter, returning the converted model.
// The starting version is taken from the first default-domain opset_import
// entry; if the model declares none, the starting version remains 0.
ModelProto ConvertVersion(
    const ModelProto& mp_in,
    int target_version) {
  OpSetID initial_struct(0);
  for (const auto& opset : mp_in.opset_import()) {
    const bool is_default_domain =
        opset.domain().empty() || opset.domain() == "ai.onnx";
    if (!is_default_domain) {
      continue;
    }
    initial_struct.setVersion(opset.version());
    break;
  }
  DefaultVersionConverter converter;
  return converter.convert_version(mp_in, initial_struct, OpSetID(target_version));
}
24 
convert_graph(std::shared_ptr<Graph> g,const OpSetID & initial_version,const OpSetID & target_version) const25 void DefaultVersionConverter::convert_graph(
26   std::shared_ptr<Graph> g,
27   const OpSetID& initial_version,
28   const OpSetID& target_version
29   ) const {
30   assertNonNull(g);
31 
32   // TODO: Move to Inter-Domain Converter
33   // Get initial model versions
34   // std::vector<OpSetID> initial_versions = g->opset_versions_mutable();
35 
36   // No conversion necessary if Model has single, equivalent opset version
37   // if (initial_versions.size() == 1 && initial_versions[0].version ==
38   //    target_version.version && initial_versions[0].domain ==
39   //    target_version.domain) {
40   //  return mp_in;
41   // }
42 
43   // Check if versions are valid
44   assertInVersionRange(initial_version.version());
45   assertInVersionRange(target_version.version());
46 
47   // Iterate over all versions to target_version for specified
48   int64_t curr_version = initial_version.version();
49   int64_t step;
50   if (target_version.version() > initial_version.version()) {
51     step = 1;
52   } else {
53     step = -1;
54   }
55   // Identify index of this domain in g.opset_versions
56   unsigned int domain_index = 0;
57   for (unsigned int i = 0; i < g->opset_versions_mutable().size(); i++) {
58     if (g->opset_versions_mutable()[i].domain() == "") {
59       domain_index = i;
60     }
61   }
62   while (curr_version != target_version.version()) {
63     debug("curr_version: " + ONNX_NAMESPACE::to_string(curr_version) + ", next_version: " +
64         ONNX_NAMESPACE::to_string(curr_version + step));
65     Node *cur_op;
66     graph_node_list_iterator it = g->begin();
67     // Iterate through and call adapter returned by adapter_lookup for ops from
68     // current_version opset. We have to manipulate the iterator explicitly because cur_op
69     // might change when applying the adapter (e.g. for deprecated ops)
70     while ( it != g->end() ) {
71       cur_op = *it;
72       debug(std::string("Finding schema for ") + std::string(cur_op->kind().toString()));
73       const std::string op_name = cur_op->kind().toString();
74       if (op_name == "ConstantFill")
75       {
76         std::cerr << "Warning: skipping schema search for experimental op 'ConstantFill' and keeping the op as is. "
77         "Please be advised the converted model may not be working properly if target runtime does not support this "
78         "experimental op." << std::endl;
79         continue;
80       }
81       if (op_name != "Undefined" && op_name != "Captured") {
82         auto& op_domain_map = all_schemas.at(op_name);
83         OpSetID curr_id(curr_version);
84         OpSetID next_id(curr_version + step);
85         if (searchOpDomainMap(op_domain_map, curr_version, step)) {
86           // Op is specifically defined for this domain and version
87           auto& op_adapter = adapter_lookup(cur_op, curr_id, next_id);
88           // If adapter_lookup returns null, no adapter is present.
89           // Error thrown by adapter_lookup
90           if (DEBUG) std::cerr << "Applying adapter" << std::endl;
91           // adapt should handle replacing node in graph
92           cur_op = op_adapter.adapt(g, cur_op);
93           it = graph_node_list_iterator(cur_op, kNextDirection);
94         }
95         // Recursively convert any subgraph attributes
96         for (const auto& attr : cur_op->attributeNames()) {
97           if (cur_op->kindOf(attr) == AttributeKind::g) {
98             convert_graph(cur_op->g(attr), curr_id, next_id);
99           }
100         }
101       }
102       it++;
103     }
104     // Update model version
105     curr_version += step;
106     g->opset_versions_mutable()[domain_index].incrementVersion(step);
107   }
108 }
109 
// Converts mp_in from initial_version to target_version and returns the result
// as a new ModelProto (built via PrepareOutput(mp_in) and then filled by
// exporting the converted graph).
//
// Both versions must pass assertDefaultDomain. Every opset_import entry of
// mp_in whose domain equals initial_version.domain() must declare exactly
// initial_version.version(); otherwise ONNX_ASSERTM fails.
ModelProto DefaultVersionConverter::convert_version(
    const ModelProto& mp_in,
    const OpSetID& initial_version,
    const OpSetID& target_version) const {
  assertDefaultDomain(initial_version.domain(), target_version.domain());

  // The caller's starting opset must match what the model itself declares.
  for (const auto& opset : mp_in.opset_import()) {
    if (opset.domain() != initial_version.domain()) {
      continue;
    }
    ONNX_ASSERTM(initial_version.version() == opset.version(),
        "initial_version does not reflect current state of model");
  }

  // Convert an IR copy of the model's graph in place.
  std::shared_ptr<Graph> graph(ImportModelProto(mp_in));
  convert_graph(graph, initial_version, target_version);

  // Serialize the converted graph into the output model.
  debug("Finished conversion; returning model");
  ModelProto mp_out = PrepareOutput(mp_in);
  ExportModelProto(&mp_out, graph);
  return mp_out;
}
136 
137 }} // namespace ONNX_NAMESPACE::version_conversion
138