1# SPDX-License-Identifier: Apache-2.0
2
3from __future__ import absolute_import
4from __future__ import division
5from __future__ import print_function
6from __future__ import unicode_literals
7
8import os
9
10from .onnx_cpp2py_export import ONNX_ML
11from onnx.external_data_helper import load_external_data_for_model, write_external_data_tensors, convert_model_to_external_data
12from .onnx_pb import *  # noqa
13from .onnx_operators_pb import * # noqa
14from .onnx_data_pb import * # noqa
15from .version import version as __version__  # noqa
16
17# Import common subpackages so they're available when you 'import onnx'
18import onnx.checker  # noqa
19import onnx.defs  # noqa
20import onnx.helper  # noqa
21import onnx.utils  # noqa
22
23import google.protobuf.message
24
25from typing import Union, Text, IO, Optional, cast, TypeVar, Any
26from six import string_types
27
28
def _load_bytes(f):  # type: (Union[IO[bytes], Text]) -> bytes
    '''
    Read the entire contents of f as bytes

    @params
    f is either a readable binary file-like object (anything with a callable
    "read" attribute) or a path to a file, which is opened in binary mode

    @return
    The bytes returned by f.read(), or the full contents of the file at path f
    '''
    if not (hasattr(f, 'read') and callable(cast(IO[bytes], f).read)):
        # Not stream-like: treat f as a path and read the whole file.
        with open(cast(Text, f), 'rb') as handle:
            return handle.read()
    return cast(IO[bytes], f).read()
37
38
def _save_bytes(str, f):  # type: (bytes, Union[IO[bytes], Text]) -> None
    '''
    Write the given bytes to f

    @params
    str is the bytes payload to write (the parameter name shadows the builtin
    and is kept only so existing keyword callers keep working)
    f is either a writable binary file-like object (anything with a callable
    "write" attribute) or a path; a path is opened in "wb" mode, so an
    existing file at that path is truncated
    '''
    is_stream = hasattr(f, 'write') and callable(cast(IO[bytes], f).write)
    if not is_stream:
        with open(cast(Text, f), 'wb') as sink:
            sink.write(str)
        return
    cast(IO[bytes], f).write(str)
47
48
def _get_file_path(f):  # type: (Union[IO[bytes], Text]) -> Optional[Text]
    '''
    Best-effort lookup of the absolute path behind f. Callers use its directory
    to locate external tensor data.

    @params
    f is a file path (a string, or an os.PathLike such as pathlib.Path) or a
    file-like object

    @return
    The absolute path, or None when f carries no usable path (e.g. an
    in-memory stream without a "name" attribute, or a file object whose
    "name" is an integer file descriptor)
    '''
    if isinstance(f, string_types):
        return os.path.abspath(f)
    # os.PathLike objects (e.g. pathlib.Path) also have a "name" attribute, but
    # it is only the final path component; resolving it would silently point
    # at the current working directory. Use the full path via __fspath__.
    # getattr keeps this a no-op on interpreters without the fspath protocol.
    fspath = getattr(f, '__fspath__', None)
    if callable(fspath):
        return os.path.abspath(fspath())
    name = getattr(f, 'name', None)
    # File objects opened from a descriptor (os.fdopen, tempfile.TemporaryFile
    # on POSIX) report an integer "name"; it is not a path, and passing it to
    # os.path.abspath would raise TypeError.
    if name is None or isinstance(name, int):
        return None
    return os.path.abspath(name)
56
57
def _serialize(proto):  # type: (Union[bytes, google.protobuf.message.Message]) -> bytes
    '''
    Serialize an in-memory proto to bytes

    @params
    proto is either already-serialized bytes (returned unchanged) or an
    in-memory proto, such as a ModelProto, TensorProto, etc

    @return
    Serialized proto in bytes

    @raises
    TypeError if proto is neither bytes nor an object with a callable
    SerializeToString method
    '''
    if isinstance(proto, bytes):
        return proto
    serializer = getattr(proto, 'SerializeToString', None)
    if callable(serializer):
        return serializer()
    # The accepted raw type is bytes (not str), so say that in the message;
    # the old text claimed "neither proto is a str", which was wrong for str input.
    raise TypeError('No SerializeToString method is detected, '
                    'and proto is not bytes.\ntype is {}'.format(type(proto)))
76
77
# Type variable bound to protobuf messages, so _deserialize is typed as
# returning the same concrete message class it is given.
_Proto = TypeVar(
    '_Proto',
    bound=google.protobuf.message.Message,
)
79
80
def _deserialize(s, proto):  # type: (bytes, _Proto) -> _Proto
    '''
    Parse serialized bytes into the given in-memory proto

    @params
    s is bytes containing a serialized proto
    proto is an in-memory proto object; it is filled in place by calling its
    ParseFromString method

    @return
    The same proto instance, now populated from s

    @raises
    ValueError if s is not bytes, or proto has no callable ParseFromString
    google.protobuf.message.DecodeError if ParseFromString returns a byte
    count that differs from len(s)
    '''
    if not isinstance(s, bytes):
        raise ValueError('Parameter s must be bytes, but got type: {}'.format(type(s)))

    parse = getattr(proto, 'ParseFromString', None)
    if not callable(parse):
        raise ValueError('No ParseFromString method is detected. '
                         '\ntype is {}'.format(type(proto)))

    # ParseFromString may return None, in which case the length is not checked;
    # when it returns a count, that count must cover all of s.
    consumed = cast(Optional[int], parse(s))
    if consumed is None or consumed == len(s):
        return proto
    raise google.protobuf.message.DecodeError(
        "Protobuf decoding consumed too few bytes: {} out of {}".format(
            consumed, len(s)))
105
106
def load_model(f, format=None, load_external_data=True):  # type: (Union[IO[bytes], Text], Optional[Any], bool) -> ModelProto
    '''
    Loads a serialized ModelProto into memory

    @params
    f can be a file-like object (has "read" function) or a string containing a file name
    format is for future use; it is passed through to load_model_from_string
    load_external_data: if true and a path can be determined for f (f is a path
    string, or a file object with a "name" attribute), the directory containing
    f is passed to load_external_data_for_model. If false, or if f has no path
    (e.g. an in-memory stream), external data is not loaded; call
    load_external_data_for_model with the appropriate directory to load it.

    @return
    Loaded in-memory ModelProto
    '''
    model = load_model_from_string(_load_bytes(f), format=format)

    # Only look up a path when external data is wanted.
    model_filepath = _get_file_path(f) if load_external_data else None
    if model_filepath:
        load_external_data_for_model(model, os.path.dirname(model_filepath))

    return model
130
131
def load_tensor(f, format=None):  # type: (Union[IO[bytes], Text], Optional[Any]) -> TensorProto
    '''
    Loads a serialized TensorProto into memory

    Unlike load_model, no external-data loading is attempted here.

    @params
    f can be a file-like object (has "read" function) or a string containing a file name
    format is for future use; it is passed through to load_tensor_from_string

    @return
    Loaded in-memory TensorProto
    '''
    return load_tensor_from_string(_load_bytes(f), format=format)
145
146
def load_model_from_string(s, format=None):  # type: (bytes, Optional[Any]) -> ModelProto
    '''
    Parses bytes holding a serialized ModelProto

    @params
    s is a bytes object containing a serialized ModelProto
    format is for future use; it is currently ignored

    @return
    A new in-memory ModelProto populated from s
    '''
    empty_model = ModelProto()
    return _deserialize(s, empty_model)
159
160
def load_tensor_from_string(s, format=None):  # type: (bytes, Optional[Any]) -> TensorProto
    '''
    Parses bytes holding a serialized TensorProto

    @params
    s is a bytes object containing a serialized TensorProto
    format is for future use; it is currently ignored

    @return
    A new in-memory TensorProto populated from s
    '''
    empty_tensor = TensorProto()
    return _deserialize(s, empty_tensor)
173
174
def save_model(proto, f, format=None, save_as_external_data=False, all_tensors_to_one_file=True, location=None, size_threshold=1024, convert_attribute=False):
    # type: (Union[ModelProto, bytes], Union[IO[bytes], Text], Optional[Any], bool, bool, Optional[Text], int, bool) -> None
    '''
    Saves the ModelProto to the specified path and, optionally, serializes
    tensors with raw data as external data before saving.

    @params
    proto: an in-memory ModelProto, or bytes of a serialized ModelProto (parsed
           into a ModelProto first so external-data handling can run on it)
    f: can be a file-like object (has "write" function) or a string containing a file name
    format: for future use; currently ignored
    save_as_external_data: if true, convert_model_to_external_data is applied to
                           the model (with the four arguments below) before saving
    all_tensors_to_one_file: If true, save all tensors to one external file specified by location.
                             If false, save each tensor to a file named with the tensor name.
    location: the external file that all tensors are saved to.
              If None, the file name is chosen by convert_model_to_external_data.
    size_threshold: Threshold for size of data. Only when tensor's data is >= the size_threshold it will be converted
                    to external data. To convert every tensor with raw data to external data set size_threshold=0.
    convert_attribute: If true, convert all tensors to external data
                       If false, convert only non-attribute tensors to external data

    When a path can be determined for f, write_external_data_tensors is called
    with the directory containing f before the model itself is serialized.
    NOTE(review): with save_as_external_data=True and an f that has no path
    (e.g. io.BytesIO), tensors are converted but write_external_data_tensors is
    never called; confirm this is intended.
    '''
    model = _deserialize(proto, ModelProto()) if isinstance(proto, bytes) else proto

    if save_as_external_data:
        convert_model_to_external_data(model, all_tensors_to_one_file, location, size_threshold, convert_attribute)

    model_filepath = _get_file_path(f)
    if model_filepath:
        model = write_external_data_tensors(model, os.path.dirname(model_filepath))

    _save_bytes(_serialize(model), f)
205
206
def save_tensor(proto, f):  # type: (TensorProto, Union[IO[bytes], Text]) -> None
    '''
    Saves the TensorProto to the specified path.

    No external-data handling is performed, unlike save_model.

    @params
    proto should be an in-memory TensorProto (bytes are also accepted and
    written unchanged)
    f can be a file-like object (has "write" function) or a string containing a file name
    '''
    _save_bytes(_serialize(proto), f)
218
219
# Aliases kept for backward compatibility with the shorter public names.
load, load_from_string, save = load_model, load_model_from_string, save_model
224