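# Helper script for generating ONNX operator metadata and converting models.
# Commands (dispatched from the command table at the bottom of this file):
#   metadata - regenerate ../src/onnx-metadata.json from the ONNX operator schemas
#   convert  - convert a .mlmodel, .h5 or .pkl model to ONNX, or dump a .onnx model to .pbtxt
#   optimize - run all available onnx.optimizer passes over a model
#   infer    - run ONNX shape inference over a model
# Example invocation (assuming this file is saved as onnx-script.py; the name is hypothetical):
#   python onnx-script.py metadata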
from __future__ import unicode_literals

import io
import json
import os
import re
import sys

from onnx import defs
from onnx.defs import OpSchema
from onnx.backend.test.case import collect_snippets

# Example code snippets collected from the ONNX backend test cases, keyed by operator name.
snippets = collect_snippets()

categories = {
    'Constant': 'Constant',

    'Conv': 'Layer',
    'ConvTranspose': 'Layer',
    'FC': 'Layer',
    'RNN': 'Layer',
    'LSTM': 'Layer',
    'GRU': 'Layer',
    'Gemm': 'Layer',

    'Dropout': 'Dropout',

    'Elu': 'Activation',
    'HardSigmoid': 'Activation',
    'LeakyRelu': 'Activation',
    'PRelu': 'Activation',
    'ThresholdedRelu': 'Activation',
    'Relu': 'Activation',
    'Selu': 'Activation',
    'Sigmoid': 'Activation',
    'Tanh': 'Activation',
    'LogSoftmax': 'Activation',
    'Softmax': 'Activation',
    'Softplus': 'Activation',
    'Softsign': 'Activation',

    'BatchNormalization': 'Normalization',
    'InstanceNormalization': 'Normalization',
    'LpNormalization': 'Normalization',
    'LRN': 'Normalization',

    'Flatten': 'Shape',
    'Reshape': 'Shape',
    'Tile': 'Shape',

    'Xor': 'Logic',
    'Not': 'Logic',
    'Or': 'Logic',
    'Less': 'Logic',
    'And': 'Logic',
    'Greater': 'Logic',
    'Equal': 'Logic',

    'AveragePool': 'Pool',
    'GlobalAveragePool': 'Pool',
    'GlobalLpPool': 'Pool',
    'GlobalMaxPool': 'Pool',
    'LpPool': 'Pool',
    'MaxPool': 'Pool',
    'MaxRoiPool': 'Pool',

    'Concat': 'Tensor',
    'Slice': 'Tensor',
    'Split': 'Tensor',
    'Pad': 'Tensor',

    'ImageScaler': 'Data',
    'Crop': 'Data',
    'Upsample': 'Data',

    'Transpose': 'Transform',
    'Gather': 'Transform',
    'Unsqueeze': 'Transform',
    'Squeeze': 'Transform',
}

attribute_type_table = {
    'undefined': None,
    'float': 'float32', 'int': 'int64', 'string': 'string', 'tensor': 'tensor', 'graph': 'graph',
    'floats': 'float32[]', 'ints': 'int64[]', 'strings': 'string[]', 'tensors': 'tensor[]', 'graphs': 'graph[]',
}

def generate_json_attr_type(attr_type):
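    """Map an OpSchema.AttrType to the JSON type string used in the metadata,
    e.g. AttrType.FLOATS -> 'float32[]'. Returns None for undefined or unknown types."""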
    assert isinstance(attr_type, OpSchema.AttrType)
    s = str(attr_type)
    s = s[s.rfind('.')+1:].lower()
    if s in attribute_type_table:
        return attribute_type_table[s]
    return None

def generate_json_attr_default_value(attr_value):
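    """Return the default value stored in an attribute proto (int, string or
    float field), or None if no default is set."""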
    if not str(attr_value):
        return None
    if attr_value.HasField('i'):
        return attr_value.i
    if attr_value.HasField('s'):
        return attr_value.s.decode('utf8')
    if attr_value.HasField('f'):
        return attr_value.f
    return None

def generate_json_support_level_name(support_level):
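    """Convert an OpSchema.SupportType enum value to a lowercase name,
    e.g. SupportType.COMMON -> 'common'."""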
    assert isinstance(support_level, OpSchema.SupportType)
    s = str(support_level)
    return s[s.rfind('.')+1:].lower()

def generate_json_types(types):
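    """Return the given type strings as a sorted list."""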
    return sorted(types)

def format_range(value):
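    """Format an input/output count bound; 2147483647 (INT32_MAX, the sentinel
    ONNX uses for an unbounded count) is rendered as '∞'."""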
    if value == 2147483647:
        return '∞'
    return str(value)

def format_description(description):
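    """Rewrite relative Markdown links in an operator description to absolute
    URLs into the onnx/onnx docs, e.g. [Broadcasting](Broadcasting.md) becomes
    [Broadcasting](https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md)."""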
    def replace_line(match):
        link = match.group(1)
        url = match.group(2)
        if not url.startswith("http://") and not url.startswith("https://"):
            url = "https://github.com/onnx/onnx/blob/master/docs/" + url
        return "[" + link + "](" + url + ")"
    description = re.sub(r'\[(.+)\]\(([^ ]+?)( "(.+)")?\)', replace_line, description)
    return description

def generate_json(schemas, json_file):
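    """Serialize the given operator schemas to json_file as a sorted, indented
    JSON list of { 'name': ..., 'schema': ... } entries."""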
    json_root = []
    for schema in schemas:
        json_schema = {}
        if schema.domain:
            json_schema['domain'] = schema.domain
        else:
            json_schema['domain'] = 'ai.onnx'
        json_schema['since_version'] = schema.since_version
        json_schema['support_level'] = generate_json_support_level_name(schema.support_level)
        if schema.doc:
            json_schema['description'] = format_description(schema.doc.lstrip())
        if schema.inputs:
            json_schema['inputs'] = []
            for input in schema.inputs:
                json_input = {}
                json_input['name'] = input.name
                json_input['description'] = format_description(input.description)
                json_input['type'] = input.typeStr
                if input.option == OpSchema.FormalParameterOption.Optional:
                    json_input['option'] = 'optional'
                elif input.option == OpSchema.FormalParameterOption.Variadic:
                    json_input['option'] = 'variadic'
                json_schema['inputs'].append(json_input)
        json_schema['min_input'] = schema.min_input
        json_schema['max_input'] = schema.max_input
        if schema.outputs:
            json_schema['outputs'] = []
            for output in schema.outputs:
                json_output = {}
                json_output['name'] = output.name
                json_output['description'] = format_description(output.description)
                json_output['type'] = output.typeStr
                if output.option == OpSchema.FormalParameterOption.Optional:
                    json_output['option'] = 'optional'
                elif output.option == OpSchema.FormalParameterOption.Variadic:
                    json_output['option'] = 'variadic'
                json_schema['outputs'].append(json_output)
        json_schema['min_output'] = schema.min_output
        json_schema['max_output'] = schema.max_output
        if schema.min_input != schema.max_input:
            json_schema['inputs_range'] = format_range(schema.min_input) + ' - ' + format_range(schema.max_input)
        if schema.min_output != schema.max_output:
            json_schema['outputs_range'] = format_range(schema.min_output) + ' - ' + format_range(schema.max_output)
        if schema.attributes:
            json_schema['attributes'] = []
            for _, attribute in sorted(schema.attributes.items()):
                json_attribute = {}
                json_attribute['name'] = attribute.name
                json_attribute['description'] = format_description(attribute.description)
                attribute_type = generate_json_attr_type(attribute.type)
                if attribute_type:
                    json_attribute['type'] = attribute_type
                json_attribute['required'] = attribute.required
                # Only record defaults that are actually set; falsy defaults such as 0 or '' are omitted.
                default_value = generate_json_attr_default_value(attribute.default_value)
                if default_value:
                    json_attribute['default'] = default_value
                json_schema['attributes'].append(json_attribute)
        if schema.type_constraints:
            json_schema['type_constraints'] = []
            for type_constraint in schema.type_constraints:
                json_schema['type_constraints'].append({
                    'description': type_constraint.description,
                    'type_param_str': type_constraint.type_param_str,
                    'allowed_type_strs': type_constraint.allowed_type_strs
                })
        if schema.name in snippets:
            json_schema['examples'] = []
            for summary, code in sorted(snippets[schema.name]):
                json_schema['examples'].append({
                    'summary': summary,
                    'code': code
                })
        if schema.name in categories:
            json_schema['category'] = categories[schema.name]
        json_root.append({
            'name': schema.name,
            'schema': json_schema
        })
    with io.open(json_file, 'w', newline='') as fout:
        content = json.dumps(json_root, sort_keys=True, indent=2)
        for line in content.splitlines():
            line = line.rstrip()
            if sys.version_info[0] < 3:
                line = unicode(line)
            fout.write(line)
            fout.write('\n')

def metadata():
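    """Generate ../src/onnx-metadata.json from all ONNX operator schemas,
    including historical versions."""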
    schemas = defs.get_all_schemas_with_history()
    schemas = sorted(schemas, key=lambda schema: schema.name)
    json_file = os.path.join(os.path.dirname(__file__), '../src/onnx-metadata.json')
    generate_json(schemas, json_file)

def convert():
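    """Convert the model file given in sys.argv[2] to ONNX (from Core ML
    .mlmodel, Keras .h5 or scikit-learn .pkl), or dump a .onnx file to
    protobuf text format (.pbtxt)."""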
    def pip_import(package):
        # Import the package, installing it with pip first if it is missing.
        import importlib
        try:
            importlib.import_module(package)
        except ImportError:
            import subprocess
            subprocess.call([ 'pip', 'install', '--quiet', package ])
        finally:
            globals()[package] = importlib.import_module(package)
    file = sys.argv[2]
    base, extension = os.path.splitext(file)
    if extension == '.mlmodel':
        pip_import('coremltools')
        import onnxmltools
        coreml_model = coremltools.utils.load_spec(file)
        onnx_model = onnxmltools.convert.convert_coreml(coreml_model)
        onnxmltools.utils.save_model(onnx_model, base + '.onnx')
    elif extension == '.h5':
        pip_import('tensorflow')
        pip_import('keras')
        import onnxmltools
        keras_model = keras.models.load_model(file)
        onnx_model = onnxmltools.convert.convert_keras(keras_model)
        onnxmltools.utils.save_model(onnx_model, base + '.onnx')
    elif extension == '.pkl':
        pip_import('sklearn')
        import onnxmltools
        sklearn_model = sklearn.externals.joblib.load(file)
        onnx_model = onnxmltools.convert.convert_sklearn(sklearn_model)
        onnxmltools.utils.save_model(onnx_model, base + '.onnx')
    if extension == '.onnx':
        import onnx
        from google.protobuf import text_format
        onnx_model = onnx.load(file)
        text = text_format.MessageToString(onnx_model)
        with open(base + '.pbtxt', 'w') as text_file:
            text_file.write(text)

def optimize():
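    """Run all available onnx.optimizer passes over the model in sys.argv[2]
    and save the result as <name>.optimized.onnx."""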
    import onnx
    from onnx import optimizer
    file = sys.argv[2]
    base = os.path.splitext(file)[0]
    onnx_model = onnx.load(file)
    passes = optimizer.get_available_passes()
    optimized_model = optimizer.optimize(onnx_model, passes)
    onnx.save(optimized_model, base + '.optimized.onnx')

def infer():
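    """Run ONNX shape inference over <name>.onnx (derived from sys.argv[2])
    and save the result as <name>.shape.onnx."""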
    import onnx
    from onnx import shape_inference
    file = sys.argv[2]
    base = os.path.splitext(file)[0]
    onnx_model = onnx.load(base + '.onnx')
    onnx_model = shape_inference.infer_shapes(onnx_model)
    onnx.save(onnx_model, base + '.shape.onnx')

if __name__ == '__main__':
    command_table = { 'metadata': metadata, 'convert': convert, 'optimize': optimize, 'infer': infer }
    command = sys.argv[1]
    command_table[command]()