1 /*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5 #include "onnx/version_converter/convert.h"
6
7 namespace ONNX_NAMESPACE { namespace version_conversion {
8
// Converts `mp_in` to `target_version` of the default ONNX operator set.
//
// The starting version is read from the first opset import whose domain is
// the default one (either "" or its alias "ai.onnx"). If the model declares no
// such import, the starting opset falls back to OpSetID(0).
//
// @param mp_in          model to convert; it is not modified.
// @param target_version default-domain opset version to convert to.
// @return a converted copy of the model, produced by DefaultVersionConverter.
ModelProto ConvertVersion(const ModelProto& mp_in, int target_version) {
  OpSetID initial_opset(0);
  for (const auto& opset_import : mp_in.opset_import()) {
    const bool is_default_domain =
        opset_import.domain().empty() || opset_import.domain() == "ai.onnx";
    if (is_default_domain) {
      initial_opset.setVersion(opset_import.version());
      break;
    }
  }

  const OpSetID target_opset(target_version);
  DefaultVersionConverter converter;
  return converter.convert_version(mp_in, initial_opset, target_opset);
}
24
convert_graph(std::shared_ptr<Graph> g,const OpSetID & initial_version,const OpSetID & target_version) const25 void DefaultVersionConverter::convert_graph(
26 std::shared_ptr<Graph> g,
27 const OpSetID& initial_version,
28 const OpSetID& target_version
29 ) const {
30 assertNonNull(g);
31
32 // TODO: Move to Inter-Domain Converter
33 // Get initial model versions
34 // std::vector<OpSetID> initial_versions = g->opset_versions_mutable();
35
36 // No conversion necessary if Model has single, equivalent opset version
37 // if (initial_versions.size() == 1 && initial_versions[0].version ==
38 // target_version.version && initial_versions[0].domain ==
39 // target_version.domain) {
40 // return mp_in;
41 // }
42
43 // Check if versions are valid
44 assertInVersionRange(initial_version.version());
45 assertInVersionRange(target_version.version());
46
47 // Iterate over all versions to target_version for specified
48 int64_t curr_version = initial_version.version();
49 int64_t step;
50 if (target_version.version() > initial_version.version()) {
51 step = 1;
52 } else {
53 step = -1;
54 }
55 // Identify index of this domain in g.opset_versions
56 unsigned int domain_index = 0;
57 for (unsigned int i = 0; i < g->opset_versions_mutable().size(); i++) {
58 if (g->opset_versions_mutable()[i].domain() == "") {
59 domain_index = i;
60 }
61 }
62 while (curr_version != target_version.version()) {
63 debug("curr_version: " + ONNX_NAMESPACE::to_string(curr_version) + ", next_version: " +
64 ONNX_NAMESPACE::to_string(curr_version + step));
65 Node *cur_op;
66 graph_node_list_iterator it = g->begin();
67 // Iterate through and call adapter returned by adapter_lookup for ops from
68 // current_version opset. We have to manipulate the iterator explicitly because cur_op
69 // might change when applying the adapter (e.g. for deprecated ops)
70 while ( it != g->end() ) {
71 cur_op = *it;
72 debug(std::string("Finding schema for ") + std::string(cur_op->kind().toString()));
73 const std::string op_name = cur_op->kind().toString();
74 if (op_name == "ConstantFill")
75 {
76 std::cerr << "Warning: skipping schema search for experimental op 'ConstantFill' and keeping the op as is. "
77 "Please be advised the converted model may not be working properly if target runtime does not support this "
78 "experimental op." << std::endl;
79 continue;
80 }
81 if (op_name != "Undefined" && op_name != "Captured") {
82 auto& op_domain_map = all_schemas.at(op_name);
83 OpSetID curr_id(curr_version);
84 OpSetID next_id(curr_version + step);
85 if (searchOpDomainMap(op_domain_map, curr_version, step)) {
86 // Op is specifically defined for this domain and version
87 auto& op_adapter = adapter_lookup(cur_op, curr_id, next_id);
88 // If adapter_lookup returns null, no adapter is present.
89 // Error thrown by adapter_lookup
90 if (DEBUG) std::cerr << "Applying adapter" << std::endl;
91 // adapt should handle replacing node in graph
92 cur_op = op_adapter.adapt(g, cur_op);
93 it = graph_node_list_iterator(cur_op, kNextDirection);
94 }
95 // Recursively convert any subgraph attributes
96 for (const auto& attr : cur_op->attributeNames()) {
97 if (cur_op->kindOf(attr) == AttributeKind::g) {
98 convert_graph(cur_op->g(attr), curr_id, next_id);
99 }
100 }
101 }
102 it++;
103 }
104 // Update model version
105 curr_version += step;
106 g->opset_versions_mutable()[domain_index].incrementVersion(step);
107 }
108 }
109
// Converts `mp_in` from `initial_version` to `target_version` and returns the
// converted model as a new ModelProto.
//
// Both versions must belong to the default domain (checked by
// assertDefaultDomain). Every opset import in the model whose domain equals
// initial_version's domain must carry initial_version's version number.
// Otherwise an assertion fires.
//
// @param mp_in           model to convert; it is not modified.
// @param initial_version opset the model is currently expressed in.
// @param target_version  opset to convert to.
// @return the converted model: PrepareOutput(mp_in) filled with the converted graph.
ModelProto DefaultVersionConverter::convert_version(
    const ModelProto& mp_in,
    const OpSetID& initial_version,
    const OpSetID& target_version) const {
  assertDefaultDomain(initial_version.domain(), target_version.domain());

  // The caller-supplied starting version must agree with the model's own
  // declaration for that domain.
  for (const auto& opset_import : mp_in.opset_import()) {
    if (opset_import.domain() != initial_version.domain()) {
      continue;
    }
    ONNX_ASSERTM(initial_version.version() == opset_import.version(),
        "initial_version does not reflect current state of model");
  }

  // Convert on the in-memory IR, then serialize back into a fresh proto.
  std::shared_ptr<Graph> graph(ImportModelProto(mp_in));
  convert_graph(graph, initial_version, target_version);

  debug("Finished conversion; returning model");
  ModelProto mp_out = PrepareOutput(mp_in);
  ExportModelProto(&mp_out, graph);
  return mp_out;
}
136
137 }} // namespace ONNX_NAMESPACE::version_conversion
138