1 2from __future__ import unicode_literals 3 4import onnx 5 6import json 7import io 8import os 9import re 10import sys 11 12from onnx import defs 13from onnx.defs import OpSchema 14from onnx.backend.test.case import collect_snippets 15 16snippets = collect_snippets() 17 18categories = { 19 'Constant': 'Constant', 20 21 'Conv': 'Layer', 22 'ConvTranspose': 'Layer', 23 'FC': 'Layer', 24 'RNN': 'Layer', 25 'LSTM': 'Layer', 26 'GRU': 'Layer', 27 'Gemm': 'Layer', 28 29 'Dropout': 'Dropout', 30 31 'Elu': 'Activation', 32 'HardSigmoid': 'Activation', 33 'LeakyRelu': 'Activation', 34 'PRelu': 'Activation', 35 'ThresholdedRelu': 'Activation', 36 'Relu': 'Activation', 37 'Selu': 'Activation', 38 'Sigmoid': 'Activation', 39 'Tanh': 'Activation', 40 'LogSoftmax': 'Activation', 41 'Softmax': 'Activation', 42 'Softplus': 'Activation', 43 'Softsign': 'Activation', 44 45 'BatchNormalization': 'Normalization', 46 'InstanceNormalization': 'Normalization', 47 'LpNormalization': 'Normalization', 48 'LRN': 'Normalization', 49 50 'Flatten': 'Shape', 51 'Reshape': 'Shape', 52 'Tile': 'Shape', 53 54 'Xor': 'Logic', 55 'Not': 'Logic', 56 'Or': 'Logic', 57 'Less': 'Logic', 58 'And': 'Logic', 59 'Greater': 'Logic', 60 'Equal': 'Logic', 61 62 'AveragePool': 'Pool', 63 'GlobalAveragePool': 'Pool', 64 'GlobalLpPool': 'Pool', 65 'GlobalMaxPool': 'Pool', 66 'LpPool': 'Pool', 67 'MaxPool': 'Pool', 68 'MaxRoiPool': 'Pool', 69 70 'Concat': 'Tensor', 71 'Slice': 'Tensor', 72 'Split': 'Tensor', 73 'Pad': 'Tensor', 74 75 'ImageScaler': 'Data', 76 'Crop': 'Data', 77 'Upsample': 'Data', 78 79 'Transpose': 'Transform', 80 'Gather': 'Transform', 81 'Unsqueeze': 'Transform', 82 'Squeeze': 'Transform', 83} 84 85attribute_type_table = { 86 'undefined': None, 87 'float': 'float32', 'int': 'int64', 'string': 'string', 'tensor': 'tensor', 'graph': 'graph', 88 'floats': 'float32[]', 'ints': 'int64[]', 'strings': 'string[]', 'tensors': 'tensor[]', 'graphs': 'graph[]', 89} 90 91def generate_json_attr_type(type): 92 assert isinstance(type, OpSchema.AttrType) 93 s = str(type) 94 s = s[s.rfind('.')+1:].lower() 95 if s in attribute_type_table: 96 return attribute_type_table[s] 97 return None 98 99def generate_json_attr_default_value(attr_value): 100 if not str(attr_value): 101 return None 102 if attr_value.HasField('i'): 103 return attr_value.i 104 if attr_value.HasField('s'): 105 return attr_value.s.decode('utf8') 106 if attr_value.HasField('f'): 107 return attr_value.f 108 return None 109 110def generate_json_support_level_name(support_level): 111 assert isinstance(support_level, OpSchema.SupportType) 112 s = str(support_level) 113 return s[s.rfind('.')+1:].lower() 114 115def generate_json_types(types): 116 r = [] 117 for type in types: 118 r.append(type) 119 r = sorted(r) 120 return r 121 122def format_range(value): 123 if value == 2147483647: 124 return '∞' 125 return str(value) 126 127def format_description(description): 128 def replace_line(match): 129 link = match.group(1) 130 url = match.group(2) 131 if not url.startswith("http://") and not url.startswith("https://"): 132 url = "https://github.com/onnx/onnx/blob/master/docs/" + url 133 return "[" + link + "](" + url + ")"; 134 description = re.sub("\\[(.+)\\]\\(([^ ]+?)( \"(.+)\")?\\)", replace_line, description) 135 return description 136 137def generate_json(schemas, json_file): 138 json_root = [] 139 for schema in schemas: 140 json_schema = {} 141 if schema.domain: 142 json_schema['domain'] = schema.domain 143 else: 144 json_schema['domain'] = 'ai.onnx' 145 json_schema['since_version'] = schema.since_version 146 json_schema['support_level'] = generate_json_support_level_name(schema.support_level) 147 if schema.doc: 148 json_schema['description'] = format_description(schema.doc.lstrip()) 149 if schema.inputs: 150 json_schema['inputs'] = [] 151 for input in schema.inputs: 152 json_input = {} 153 json_input['name'] = input.name 154 json_input['description'] = format_description(input.description) 155 json_input['type'] = input.typeStr 156 if input.option == OpSchema.FormalParameterOption.Optional: 157 json_input['option'] = 'optional' 158 elif input.option == OpSchema.FormalParameterOption.Variadic: 159 json_input['option'] = 'variadic' 160 json_schema['inputs'].append(json_input) 161 json_schema['min_input'] = schema.min_input 162 json_schema['max_input'] = schema.max_input 163 if schema.outputs: 164 json_schema['outputs'] = [] 165 for output in schema.outputs: 166 json_output = {} 167 json_output['name'] = output.name 168 json_output['description'] = format_description(output.description) 169 json_output['type'] = output.typeStr 170 if output.option == OpSchema.FormalParameterOption.Optional: 171 json_output['option'] = 'optional' 172 elif output.option == OpSchema.FormalParameterOption.Variadic: 173 json_output['option'] = 'variadic' 174 json_schema['outputs'].append(json_output) 175 json_schema['min_output'] = schema.min_output 176 json_schema['max_output'] = schema.max_output 177 if schema.min_input != schema.max_input: 178 json_schema['inputs_range'] = format_range(schema.min_input) + ' - ' + format_range(schema.max_input); 179 if schema.min_output != schema.max_output: 180 json_schema['outputs_range'] = format_range(schema.min_output) + ' - ' + format_range(schema.max_output); 181 if schema.attributes: 182 json_schema['attributes'] = [] 183 for _, attribute in sorted(schema.attributes.items()): 184 json_attribute = {} 185 json_attribute['name'] = attribute.name 186 json_attribute['description'] = format_description(attribute.description) 187 attribute_type = generate_json_attr_type(attribute.type) 188 if attribute_type: 189 json_attribute['type'] = attribute_type 190 elif 'type' in json_attribute: 191 del json_attribute['type'] 192 json_attribute['required'] = attribute.required 193 default_value = generate_json_attr_default_value(attribute.default_value) 194 if default_value: 195 json_attribute['default'] = default_value 196 json_schema['attributes'].append(json_attribute) 197 if schema.type_constraints: 198 json_schema['type_constraints'] = [] 199 for type_constraint in schema.type_constraints: 200 json_schema['type_constraints'].append({ 201 'description': type_constraint.description, 202 'type_param_str': type_constraint.type_param_str, 203 'allowed_type_strs': type_constraint.allowed_type_strs 204 }) 205 if schema.name in snippets: 206 json_schema['examples'] = [] 207 for summary, code in sorted(snippets[schema.name]): 208 json_schema['examples'].append({ 209 'summary': summary, 210 'code': code 211 }) 212 if schema.name in categories: 213 json_schema['category'] = categories[schema.name] 214 json_root.append({ 215 'name': schema.name, 216 'schema': json_schema 217 }) 218 with io.open(json_file, 'w', newline='') as fout: 219 json_root = json.dumps(json_root, sort_keys=True, indent=2) 220 for line in json_root.splitlines(): 221 line = line.rstrip() 222 if sys.version_info[0] < 3: 223 line = unicode(line) 224 fout.write(line) 225 fout.write('\n') 226 227def metadata(): 228 schemas = defs.get_all_schemas_with_history() 229 schemas = sorted(schemas, key=lambda schema: schema.name) 230 json_file = os.path.join(os.path.dirname(__file__), '../src/onnx-metadata.json') 231 generate_json(schemas, json_file) 232 233def convert(): 234 def pip_import(package): 235 import importlib 236 try: 237 importlib.import_module(package) 238 except: 239 import subprocess 240 subprocess.call([ 'pip', 'install', '--quiet', package ]) 241 finally: 242 globals()[package] = importlib.import_module(package) 243 file = sys.argv[2] 244 base, extension = os.path.splitext(file) 245 if extension == '.mlmodel': 246 pip_import('coremltools') 247 import onnxmltools 248 coreml_model = coremltools.utils.load_spec(file) 249 onnx_model = onnxmltools.convert.convert_coreml(coreml_model) 250 onnxmltools.utils.save_model(onnx_model, base + '.onnx') 251 elif extension == '.h5': 252 pip_import('tensorflow') 253 pip_import('keras') 254 import onnxmltools 255 keras_model = keras.models.load_model(file) 256 onnx_model = onnxmltools.convert.convert_keras(keras_model) 257 onnxmltools.utils.save_model(onnx_model, base + '.onnx') 258 elif extension == '.pkl': 259 pip_import('sklearn') 260 import onnxmltools 261 sklearn_model = sklearn.externals.joblib.load(file) 262 onnx_model = onnxmltools.convert.convert_sklearn(sklearn_model) 263 onnxmltools.utils.save_model(onnx_model, base + '.onnx') 264 base, extension = os.path.splitext(file) 265 if extension == '.onnx': 266 import onnx 267 from google.protobuf import text_format 268 onnx_model = onnx.load(file) 269 text = text_format.MessageToString(onnx_model) 270 with open(base + '.pbtxt', 'w') as text_file: 271 text_file.write(text) 272 273def optimize(): 274 import onnx 275 from onnx import optimizer 276 file = sys.argv[2] 277 base = os.path.splitext(file) 278 onnx_model = onnx.load(file) 279 passes = optimizer.get_available_passes() 280 optimized_model = optimizer.optimize(onnx_model, passes) 281 onnx.save(optimized_model, base + '.optimized.onnx') 282 283def infer(): 284 import onnx 285 import onnx.shape_inference 286 from onnx import shape_inference 287 file = sys.argv[2] 288 base = os.path.splitext(file)[0] 289 onnx_model = onnx.load(base + '.onnx') 290 onnx_model = onnx.shape_inference.infer_shapes(onnx_model) 291 onnx.save(onnx_model, base + '.shape.onnx') 292 293if __name__ == '__main__': 294 command_table = { 'metadata': metadata, 'convert': convert, 'optimize': optimize, 'infer': infer } 295 command = sys.argv[1] 296 command_table[command]() 297