# SPDX-License-Identifier: Apache-2.0

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os

from .onnx_cpp2py_export import ONNX_ML
from onnx.external_data_helper import load_external_data_for_model, write_external_data_tensors, convert_model_to_external_data
from .onnx_pb import *  # noqa
from .onnx_operators_pb import *  # noqa
from .onnx_data_pb import *  # noqa
from .version import version as __version__  # noqa

# Import common subpackages so they're available when you 'import onnx'
import onnx.checker  # noqa
import onnx.defs  # noqa
import onnx.helper  # noqa
import onnx.utils  # noqa

import google.protobuf.message

from typing import Union, Text, IO, Optional, cast, TypeVar, Any
from six import string_types


# f should be either readable or a file path
def _load_bytes(f):  # type: (Union[IO[bytes], Text]) -> bytes
    '''
    Read the entire content of f as bytes.

    @params
    f is either an object with a callable "read" method, or a path that
    open() accepts (opened in binary mode and closed afterwards)

    @return
    Everything read from f
    '''
    if hasattr(f, 'read') and callable(cast(IO[bytes], f).read):
        s = cast(IO[bytes], f).read()
    else:
        with open(cast(Text, f), 'rb') as readable:
            s = readable.read()
    return s


# str should be bytes,
# f should be either writable or a file path
def _save_bytes(str, f):  # type: (bytes, Union[IO[bytes], Text]) -> None
    '''
    Write bytes to f.

    @params
    str is the bytes to write (the name shadows the builtin but is kept for
    backward compatibility with keyword callers)
    f is either an object with a callable "write" method, or a path that
    open() accepts (opened in binary write mode, truncating any existing file)
    '''
    if hasattr(f, 'write') and callable(cast(IO[bytes], f).write):
        cast(IO[bytes], f).write(str)
    else:
        with open(cast(Text, f), 'wb') as writable:
            writable.write(str)


# f should be either a readable file or a file path
def _get_file_path(f):  # type: (Union[IO[bytes], Text]) -> Optional[Text]
    '''
    Best-effort resolution of the absolute file-system path behind f.

    Used by load_model/save_model to locate the directory that holds
    external tensor data.

    @params
    f is a text path, an os.PathLike (e.g. pathlib.Path), or a file-like object

    @return
    The absolute path, or None when no text path can be determined
    (e.g. io.BytesIO, or file objects whose "name" is an int file descriptor)
    '''
    if isinstance(f, string_types):
        return os.path.abspath(f)

    # Path objects expose a ".name" attribute that is only the basename, so
    # they must be converted via os.fspath before falling back to ".name";
    # otherwise the path would be resolved against the current directory.
    # os.PathLike only exists on Python 3.6+, hence the getattr.
    path_like = getattr(os, 'PathLike', None)
    if path_like is not None and isinstance(f, path_like):
        return os.path.abspath(os.fspath(f))

    # Files opened from a descriptor (os.fdopen, tempfile.TemporaryFile on
    # POSIX) report an int "name"; os.path.abspath would raise TypeError.
    name = getattr(f, 'name', None)
    if isinstance(name, string_types):
        return os.path.abspath(name)
    return None


def _serialize(proto):  # type: (Union[bytes, google.protobuf.message.Message]) -> bytes
    '''
    Serialize an in-memory proto to bytes

    @params
    proto is an in-memory proto, such as a ModelProto, TensorProto, etc.
    Bytes are returned unchanged.

    @return
    Serialized proto in bytes

    @raises
    TypeError if proto is neither bytes nor has a callable SerializeToString
    '''
    if isinstance(proto, bytes):
        return proto
    elif hasattr(proto, 'SerializeToString') and callable(proto.SerializeToString):
        result = proto.SerializeToString()
        return result
    else:
        raise TypeError('No SerializeToString method is detected. '
                        'neither proto is a str.\ntype is {}'.format(type(proto)))


_Proto = TypeVar('_Proto', bound=google.protobuf.message.Message)


def _deserialize(s, proto):  # type: (bytes, _Proto) -> _Proto
    '''
    Parse bytes into an in-memory proto

    @params
    s is bytes containing a serialized proto
    proto is an in-memory proto object; it is filled in place

    @return
    The same proto instance, filled in by s

    @raises
    ValueError if s is not bytes or proto has no callable ParseFromString
    google.protobuf.message.DecodeError if ParseFromString reports a consumed
    byte count different from len(s)
    '''
    if not isinstance(s, bytes):
        raise ValueError('Parameter s must be bytes, but got type: {}'.format(type(s)))

    if not (hasattr(proto, 'ParseFromString') and callable(proto.ParseFromString)):
        raise ValueError('No ParseFromString method is detected. '
                         '\ntype is {}'.format(type(proto)))

    # Some protobuf implementations return the number of bytes consumed,
    # others return None; only the former can be checked.
    decoded = cast(Optional[int], proto.ParseFromString(s))
    if decoded is not None and decoded != len(s):
        raise google.protobuf.message.DecodeError(
            "Protobuf decoding consumed too few bytes: {} out of {}".format(
                decoded, len(s)))
    return proto


def load_model(f, format=None, load_external_data=True):  # type: (Union[IO[bytes], Text], Optional[Any], bool) -> ModelProto
    '''
    Loads a serialized ModelProto into memory

    When load_external_data is true and the file-system location of f can be
    determined, external tensor data is loaded from the directory containing
    the model file. Otherwise (or when f is e.g. an in-memory buffer) callers
    must call load_external_data_for_model themselves with the data directory.

    @params
    f can be a file-like object (has "read" function), a string containing a
    file name, or a path-like object
    format is for future use
    load_external_data controls whether external data is loaded automatically

    @return
    Loaded in-memory ModelProto
    '''
    s = _load_bytes(f)
    model = load_model_from_string(s, format=format)

    if load_external_data:
        model_filepath = _get_file_path(f)
        if model_filepath:
            base_dir = os.path.dirname(model_filepath)
            load_external_data_for_model(model, base_dir)

    return model


def load_tensor(f, format=None):  # type: (Union[IO[bytes], Text], Optional[Any]) -> TensorProto
    '''
    Loads a serialized TensorProto into memory

    @params
    f can be a file-like object (has "read" function) or a string containing a file name
    format is for future use

    @return
    Loaded in-memory TensorProto
    '''
    s = _load_bytes(f)
    return load_tensor_from_string(s, format=format)


def load_model_from_string(s, format=None):  # type: (bytes, Optional[Any]) -> ModelProto
    '''
    Loads a binary string (bytes) that contains serialized ModelProto

    @params
    s is a string, which contains serialized ModelProto
    format is for future use

    @return
    Loaded in-memory ModelProto
    '''
    return _deserialize(s, ModelProto())


def load_tensor_from_string(s, format=None):  # type: (bytes, Optional[Any]) -> TensorProto
    '''
    Loads a binary string (bytes) that contains serialized TensorProto

    @params
    s is a string, which contains serialized TensorProto
    format is for future use

    @return
    Loaded in-memory TensorProto
    '''
    return _deserialize(s, TensorProto())


def save_model(proto, f, format=None, save_as_external_data=False, all_tensors_to_one_file=True, location=None, size_threshold=1024, convert_attribute=False):
    # type: (Union[ModelProto, bytes], Union[IO[bytes], Text], Optional[Any], bool, bool, Optional[Text], int, bool) -> None
    '''
    Saves the ModelProto to the specified path and optionally, serialize tensors with raw data as external data before saving.

    @params
    proto: should be an in-memory ModelProto (serialized bytes are also accepted
        and are deserialized first)
    f: can be a file-like object (has "write" function), a string containing a
        file name, or a path-like object
    format: for future use
    save_as_external_data: If true, convert tensors to external data (subject to
        size_threshold and convert_attribute) before saving.
    all_tensors_to_one_file: If true, save all tensors to one external file specified by location.
        If false, save each tensor to a file named with the tensor name.
    location: specify the external file that all tensors to save to.
        If not specified, will use the model name.
    size_threshold: Threshold for size of data. Only when tensor's data is >= the size_threshold it will be converted
        to external data. To convert every tensor with raw data to external data set size_threshold=0.
    convert_attribute: If true, convert all tensors to external data
        If false, convert only non-attribute tensors to external data

    NOTE(review): a ModelProto argument is passed directly to
    convert_model_to_external_data / write_external_data_tensors, which
    presumably modify it in place - confirm before reusing proto afterwards.
    NOTE(review): external tensor files are only written when the location of
    f can be resolved to a path; for e.g. io.BytesIO they are skipped.
    '''
    if isinstance(proto, bytes):
        proto = _deserialize(proto, ModelProto())

    if save_as_external_data:
        convert_model_to_external_data(proto, all_tensors_to_one_file, location, size_threshold, convert_attribute)

    model_filepath = _get_file_path(f)
    if model_filepath:
        basepath = os.path.dirname(model_filepath)
        proto = write_external_data_tensors(proto, basepath)

    s = _serialize(proto)
    _save_bytes(s, f)


def save_tensor(proto, f):  # type: (TensorProto, Union[IO[bytes], Text]) -> None
    '''
    Saves the TensorProto to the specified path.

    @params
    proto should be an in-memory TensorProto (serialized bytes are written as-is)
    f can be a file-like object (has "write" function) or a string containing a file name
    '''
    s = _serialize(proto)
    _save_bytes(s, f)


# For backward compatibility
load = load_model
load_from_string = load_model_from_string
save = save_model