from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import functools
import inspect
import logging
import re
import sys
import uuid
import warnings

from onnx.backend.base import DeviceType
from tensorflow.python.client import device_lib

IS_PYTHON3 = sys.version_info > (3,)
logger = logging.getLogger('onnx-tf')

# create console handler and formatter for logger
console = logging.StreamHandler()
formatter = logging.Formatter(
    '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console.setFormatter(formatter)
logger.addHandler(console)


class Deprecated:
  """Add deprecated message when function is called.

  Usage:
    from onnx_tf.common import deprecated

    @deprecated
    def func():
      pass

    UserWarning: func is deprecated. It will be removed in future release.

    @deprecated("Message")
    def func():
      pass

    UserWarning: Message

    @deprecated({"arg": "Message",
                 "arg_1": deprecated.MSG_WILL_REMOVE,
                 "arg_2": "",})
    def func(arg, arg_1, arg_2):
      pass

    UserWarning: Message
    UserWarning: arg_1 of func is deprecated. It will be removed in future release.
    UserWarning: arg_2 of func is deprecated.

  The bare ``@deprecated`` form warns on every call of the wrapped function.
  The string and dict forms warn once, when the decorator is applied, and
  return the function unchanged.
  """

  MSG_WILL_REMOVE = " It will be removed in future release."

  def __call__(self, *args, **kwargs):
    return self.deprecated_decorator(*args, **kwargs)

  @staticmethod
  def messages():
    """Return the set of predefined message suffixes.

    These are the values of all class attributes whose names start with
    "MSG" (currently only MSG_WILL_REMOVE).
    """
    return {v for k, v in inspect.getmembers(Deprecated) if k.startswith("MSG")}

  @staticmethod
  def deprecated_decorator(arg=None):
    """Build or apply a deprecation decorator.

    :param arg: One of: the function being decorated (bare ``@deprecated``);
      a message string; or a dict mapping argument names to messages.
      ``None`` means "use MSG_WILL_REMOVE".
    :return: The wrapped function for bare usage, otherwise a decorator that
      emits the warning(s) and returns the decorated function unchanged.
    """
    # @deprecated -- deprecate the function itself with MSG_WILL_REMOVE.
    # The warning fires on every call of the wrapped function.
    if inspect.isfunction(arg):

      # Preserve __name__/__doc__/__module__ of the deprecated function.
      @functools.wraps(arg)
      def wrapper(*args, **kwargs):
        # stacklevel=2 attributes the warning to the caller, not this module.
        warnings.warn(
            "{} is deprecated.{}".format(arg.__module__ + "." + arg.__name__,
                                         Deprecated.MSG_WILL_REMOVE),
            stacklevel=2)
        return arg(*args, **kwargs)

      return wrapper

    deprecated_arg = arg if arg is not None else Deprecated.MSG_WILL_REMOVE

    def deco(func):
      # @deprecated({...}) -- deprecate individual arguments.
      if isinstance(deprecated_arg, dict):
        for name, message in deprecated_arg.items():
          # An empty message or a predefined suffix gets the standard
          # "<arg> of <func> is deprecated." prefix; any other string is
          # emitted verbatim.
          if not message or message in Deprecated.messages():
            message = "{} of {} is deprecated.{}".format(
                name, func.__module__ + "." + func.__name__, message or "")
          # stacklevel=2 points at the site where the decorator is applied.
          warnings.warn(message, stacklevel=2)
      # @deprecated("message") -- deprecate the function with a message.
      elif isinstance(deprecated_arg, str):
        message = deprecated_arg
        if message in Deprecated.messages():
          message = "{} is deprecated.{}".format(
              func.__module__ + "." + func.__name__, message)
        warnings.warn(message, stacklevel=2)
      return func

    return deco


deprecated = Deprecated()


def op_name_to_lower(name):
  """Convert a CamelCase op name to snake_case.

  An underscore is inserted before every upper-case letter except at the
  start of the string, then the whole result is lower-cased, e.g.
  "BatchNormalization" -> "batch_normalization". Consecutive capitals are
  each separated, e.g. "LRN" -> "l_r_n".

  :param name: Op name string.
  :return: Lower-case, underscore-separated name.
  """
  return re.sub(r'(?<!^)(?=[A-Z])', '_', name).lower()


def get_unique_suffix():
  """ Get unique suffix by using first 8 chars from uuid.uuid4
  to make unique identity name.

  :return: Unique suffix string.
  """
  return str(uuid.uuid4())[:8]


def get_perm_from_formats(from_, to_):
  """ Get perm from data formats.
  For example:
    get_perm_from_formats('NHWC', 'NCHW') = [0, 3, 1, 2]

  :param from_: From data format string.
  :param to_: To data format string.
  :return: Perm. Int list; for each char of ``to_``, its index in ``from_``
    (-1 if the char does not occur in ``from_``, per ``str.find``).
  """
  return [from_.find(dim) for dim in to_]


# TODO: allow more flexible placement
def get_device_option(device):
  """ Map an ONNX backend device to a TensorFlow device string prefix.

  :param device: Object with a ``type`` attribute holding a DeviceType.
  :return: '/cpu' for DeviceType.CPU, '/gpu' for DeviceType.CUDA.
  :raises KeyError: For any other device type.
  """
  m = {DeviceType.CPU: '/cpu', DeviceType.CUDA: '/gpu'}
  return m[device.type]


def get_data_format(x_rank):
  """ Get data format by input rank.
  Channel first if support CUDA.

  :param x_rank: Input rank (batch + channel + up to 3 spatial dims).
  :return: Tuple (storage_format, compute_format). storage_format is always
    channel-first ("NC" + spatial dims); compute_format is channel-first when
    a CUDA device is available, otherwise channel-last ("N" + spatial + "C").
  :raises IndexError: If x_rank > 5 (only D, H, W spatial names exist).
  """
  sp_dim_names = ["D", "H", "W"]
  # Take the last (x_rank - 2) spatial names in D, H, W order:
  # rank 3 -> "W", rank 4 -> "HW", rank 5 -> "DHW"; rank <= 2 -> "".
  sp_dim_string = "".join(
      sp_dim_names[-i - 1] for i in reversed(range(x_rank - 2)))
  storage_format = "NC" + sp_dim_string

  if supports_device("CUDA"):
    compute_format = "NC" + sp_dim_string
  else:
    compute_format = "N" + sp_dim_string + "C"
  return storage_format, compute_format


def supports_device(device):
  """ Check if support target device.

  :param device: CUDA or CPU.
  :return: True for "CPU"; for "CUDA", True iff TensorFlow lists at least one
    local device of type 'GPU'; False for anything else.
  """
  if device == "CUDA":
    return any(proto.device_type == 'GPU'
               for proto in device_lib.list_local_devices())
  return device == "CPU"


CONST_MINUS_ONE_INT32 = "_onnx_tf_internal_minus_one_int32"
CONST_ZERO_INT32 = "_onnx_tf_internal_zero_int32"
CONST_ONE_INT32 = "_onnx_tf_internal_one_int32"
CONST_ONE_FP32 = "_onnx_tf_internal_one_fp32"