from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import functools
import inspect
import logging
import re
import sys
import uuid
import warnings

from onnx.backend.base import DeviceType
from tensorflow.python.client import device_lib
15
# True when the running interpreter is Python 3 or newer.
IS_PYTHON3 = sys.version_info[0] >= 3

# Logger named 'onnx-tf' with a stream handler that prefixes every record
# with its timestamp, logger name and level.
formatter = logging.Formatter(
    '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console = logging.StreamHandler()
console.setFormatter(formatter)

logger = logging.getLogger('onnx-tf')
logger.addHandler(console)
25
26
class Deprecated:
  """Emit deprecation warnings for a function or for some of its arguments.

  The bare form wraps the function and warns on every call. The forms that
  take a message string or a dict warn once, at the moment the decorator is
  applied, and return the function unchanged.

  Usage:
    from onnx_tf.common import deprecated

    @deprecated
    def func():
      pass

    UserWarning: <module>.func is deprecated. It will be removed in future release.

    @deprecated("Message")
    def func():
      pass

    UserWarning: Message

    @deprecated({"arg": "Message",
                 "arg_1": deprecated.MSG_WILL_REMOVE,
                 "arg_2": "",})
    def func(arg, arg_1, arg_2):
      pass

    UserWarning: Message
    UserWarning: arg_1 of <module>.func is deprecated. It will be removed in future release.
    UserWarning: arg_2 of <module>.func is deprecated.
  """

  # Predefined message suffix; attributes whose name starts with "MSG" are
  # the ones collected by messages().
  MSG_WILL_REMOVE = " It will be removed in future release."

  def __call__(self, *args, **kwargs):
    """Forward to deprecated_decorator so an instance works as a decorator."""
    return self.deprecated_decorator(*args, **kwargs)

  @staticmethod
  def messages():
    """Return the set of predefined messages (class attributes named MSG*)."""
    return {v for k, v in inspect.getmembers(Deprecated) if k.startswith("MSG")}

  @staticmethod
  def deprecated_decorator(arg=None):
    """Build the deprecation decorator for the given argument.

    :param arg: The function itself (bare ``@deprecated``), a message
      string, a dict mapping argument names to messages, or None for the
      default MSG_WILL_REMOVE message.
    :return: A wrapper that warns on every call when ``arg`` is a function,
      otherwise a decorator that warns when applied and returns the
      decorated function unchanged.
    """
    # deprecate function with default message MSG_WILL_REMOVE
    # @deprecated
    if inspect.isfunction(arg):

      # Keep __name__, __doc__ and friends of the wrapped function.
      @functools.wraps(arg)
      def wrapper(*args, **kwargs):
        warnings.warn("{} is deprecated.{}".format(
            arg.__module__ + "." + arg.__name__, Deprecated.MSG_WILL_REMOVE))
        return arg(*args, **kwargs)

      return wrapper

    deprecated_arg = arg if arg is not None else Deprecated.MSG_WILL_REMOVE

    def deco(func):
      predefined = Deprecated.messages()
      # deprecate arg
      # @deprecated({...})
      if isinstance(deprecated_arg, dict):
        for name, message in deprecated_arg.items():
          # An empty (or None) message means "no extra text": expand it the
          # same way as a predefined message instead of warning with an
          # empty string.
          if not message or message in predefined:
            message = "{} of {} is deprecated.{}".format(
                name, func.__module__ + "." + func.__name__, message or "")
          warnings.warn(message)
      # deprecate function with message
      # @deprecated("message")
      elif isinstance(deprecated_arg, str):
        message = deprecated_arg
        if not message or message in predefined:
          message = "{} is deprecated.{}".format(
              func.__module__ + "." + func.__name__, message)
        warnings.warn(message)
      return func

    return deco


# Shared instance used as the ``@deprecated`` decorator.
deprecated = Deprecated()
103
104
def op_name_to_lower(name):
  """ Convert a CamelCase op name to lower case with underscores.

  An underscore is inserted in front of every upper case ASCII letter
  that is not at the very start of the string, then the whole result is
  lower-cased, e.g. "MaxPool" -> "max_pool".

  :param name: Op name string.
  :return: Lower-cased name with underscores at the upper case positions.
  """
  upper_case_boundary = re.compile('(?<!^)(?=[A-Z])')
  underscored = upper_case_boundary.sub('_', name)
  return underscored.lower()
110
111
def get_unique_suffix():
  """ Build a short random suffix for making identity names unique.

  A fresh uuid.uuid4 is generated and its first 8 hexadecimal
  characters are kept.

  :return: Unique suffix string of length 8.
  """
  fresh_id = uuid.uuid4()
  return fresh_id.hex[:8]
119
120
def get_perm_from_formats(from_, to_):
  """ Compute the permutation that turns one data format into another.

  For every character of the target format, its position in the source
  format is looked up with ``find`` (so a character missing from the
  source yields -1). For example:
    get_perm_from_formats('NHWC', 'NCHW') = [0, 3, 1, 2]

  :param from_: From data format string.
  :param to_: To data format string.
  :return: Perm. Int list.
  """
  return [from_.find(axis) for axis in to_]
131
132
# TODO: allow more flexible placement
def get_device_option(device):
  """ Map a device to its device string.

  :param device: Object whose ``type`` attribute is DeviceType.CPU or
    DeviceType.CUDA.
  :return: '/cpu' for DeviceType.CPU, '/gpu' for DeviceType.CUDA.
  :raises KeyError: If ``device.type`` is neither of the two.
  """
  placement_by_type = {
      DeviceType.CPU: '/cpu',
      DeviceType.CUDA: '/gpu',
  }
  return placement_by_type[device.type]
137
138
def get_data_format(x_rank):
  """ Derive storage and compute data formats from the input rank.

  The storage format is always channel first ("NC" + spatial dims).
  The compute format is channel first too when supports_device("CUDA")
  is true, otherwise channel last ("N" + spatial dims + "C").

  :param x_rank: Input rank; the last ``x_rank - 2`` letters of "DHW"
    become the spatial dims.
  :return: Tuple (storage_format, compute_format).
  """
  spatial_names = ["D", "H", "W"]
  # Negative indices -n .. -1 pick the last n names in their natural order.
  spatial = "".join(spatial_names[-k] for k in range(x_rank - 2, 0, -1))

  channel_first = "NC" + spatial
  if supports_device("CUDA"):
    return channel_first, channel_first
  return channel_first, "N" + spatial + "C"
159
160
def supports_device(device):
  """ Tell whether the target device is available.

  "CPU" is always reported as supported. "CUDA" is supported when
  device_lib.list_local_devices() lists at least one device whose
  device_type is 'GPU'. Any other value is unsupported.

  :param device: CUDA or CPU.
  :return: If supports.
  """
  if device == "CUDA":
    return any(
        proto.device_type == 'GPU'
        for proto in device_lib.list_local_devices())
  return device == "CPU"
174
175
# String identifiers reserved for internal constants. Judging by the names
# they stand for the values -1, 0 and 1 as int32 and 1 as fp32; where they
# are consumed is not visible in this module -- confirm against the callers.
CONST_MINUS_ONE_INT32 = "_onnx_tf_internal_minus_one_int32"
CONST_ZERO_INT32 = "_onnx_tf_internal_zero_int32"
CONST_ONE_INT32 = "_onnx_tf_internal_one_int32"
CONST_ONE_FP32 = "_onnx_tf_internal_one_fp32"
180