# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=broad-except
"""Common utilities"""
from __future__ import absolute_import as _abs
import logging
import numpy as np

import tvm
from tvm.ir import IRModule
from tvm.topi.util import get_const_tuple

from .. import expr as _expr
from .. import function as _function
from .. import transform as _transform
from .. import op as _op
from .. import analysis


class RequiredAttr(object):
    """Dummy class to represent a required attribute."""


class StrAttrsDict(object):
    """Helper class to parse attrs stored as Dict[str, str].

    Parameters
    ----------
    attrs : Dict[str, str]
        The attributes to be used.
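
    Examples
    --------
    A minimal usage sketch; the attribute strings below are made up:

    >>> attrs = StrAttrsDict({"kernel": "(3, 3)", "use_bias": "True"})
    >>> attrs.has_attr("kernel")
    True
    >>> attrs.get_int_tuple("kernel")
    (3, 3)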
    """

    def __init__(self, attrs):
        self.attrs = attrs

    def has_attr(self, key):
        """Checks if an attribute is present in the map.

        Parameters
        ----------
        key : str
            The attribute key

        Returns
        -------
        bool : True if the key is present in the attributes, else False.
        """
        return key in self.attrs

    def get_float(self, key, default=RequiredAttr()):
        """Get a float attribute.

        Parameters
        ----------
        key : str
            The attribute key

        default : float
            The default value.

        Returns
        -------
        value : float
            The attribute value if present, otherwise the default.
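
        Examples
        --------
        Illustrative only; the attribute value is made up:

        >>> attrs = StrAttrsDict({"alpha": "0.1"})
        >>> attrs.get_float("alpha")
        0.1
        >>> attrs.get_float("beta", default=1.0)
        1.0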
        """
        if key in self.attrs:
            return float(self.attrs[key])
        if isinstance(default, RequiredAttr):
            raise AttributeError("Required attribute {} not found.".format(key))
        return default

    def get_int(self, key, default=RequiredAttr()):
        """Get an int attribute.

        Parameters
        ----------
        key : str
            The attribute key

        default : int
            The default value.

        Returns
        -------
        value : int or None
            The attribute value if present (None if stored as "None"),
            otherwise the default.
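
        Examples
        --------
        Illustrative only; note the special handling of the string "None":

        >>> attrs = StrAttrsDict({"axis": "1", "out_size": "None"})
        >>> attrs.get_int("axis")
        1
        >>> attrs.get_int("out_size") is None
        True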
        """
        if key in self.attrs:
            val = self.attrs[key]
            if val == "None":
                return None
            return int(val)
        if isinstance(default, RequiredAttr):
            raise AttributeError("Required attribute {} not found.".format(key))
        return default

    def get_str(self, key, default=RequiredAttr()):
        """Get a str attribute.

        Parameters
        ----------
        key : str
            The attribute key

        default : str
            The default value.

        Returns
        -------
        value : str
            The attribute value if present, otherwise the default.
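
        Examples
        --------
        Illustrative only:

        >>> attrs = StrAttrsDict({"layout": "NCHW"})
        >>> attrs.get_str("layout")
        'NCHW'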
        """
        if key in self.attrs:
            return self.attrs[key]
        if isinstance(default, RequiredAttr):
            raise AttributeError("Required attribute {} not found.".format(key))
        return default

    def get_int_tuple(self, key, default=RequiredAttr()):
        """Get an int tuple attribute.

        Parameters
        ----------
        key : str
            The attribute key

        default : tuple of int
            The default value.

        Returns
        -------
        value : tuple of int
            The parsed tuple if the attribute is present, otherwise the
            default. Elements that are not integers parse as None.
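
        Examples
        --------
        Illustrative only; non-integer elements become None:

        >>> attrs = StrAttrsDict({"strides": "(2, 2)", "shape": "(None, 3)"})
        >>> attrs.get_int_tuple("strides")
        (2, 2)
        >>> attrs.get_int_tuple("shape")
        (None, 3)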
        """
        if key in self.attrs:
            tshape = self.attrs[key]
            return tuple(
                int(x) if x.strip("- ").isdigit() else None
                for x in tshape.strip("()[]").split(",")
                if x
            )
        if isinstance(default, RequiredAttr):
            raise AttributeError("Required attribute {} not found.".format(key))
        return default

    def get_float_tuple(self, key, default=RequiredAttr()):
        """Get a float tuple attribute.

        Parameters
        ----------
        key : str
            The attribute key

        default : tuple of float
            The default value.

        Returns
        -------
        value : tuple of float
            The parsed tuple if the attribute is present, otherwise the
            default.
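
        Examples
        --------
        Illustrative only:

        >>> attrs = StrAttrsDict({"scales": "(1.0, 2.0)"})
        >>> attrs.get_float_tuple("scales")
        (1.0, 2.0)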
        """

        if key in self.attrs:
            tshape = self.attrs[key]
            return tuple(float(x.strip()) for x in tshape.strip("()[]").split(","))
        if isinstance(default, RequiredAttr):
            raise AttributeError("Required attribute {} not found.".format(key))
        return default

    def get_tuple_tuple_int(self, key, default=RequiredAttr()):
        """Get a tuple of int tuples attribute.

        Parameters
        ----------
        key : str
            The attribute key

        default : tuple of tuple of int
            The default value.

        Returns
        -------
        value : tuple of tuple of int
            The parsed nested tuple if the attribute is present, otherwise
            the default.
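
        Examples
        --------
        Illustrative only; each inner group parses to a tuple:

        >>> attrs = StrAttrsDict({"pads": "((1, 1), (2, 2))"})
        >>> attrs.get_tuple_tuple_int("pads")
        ((1, 1), (2, 2))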
        """
        if key in self.attrs:
            value = self.attrs[key]
            seq = []
            for tup in value.strip("()").split("),"):
                tup = tup.strip("[]()")
                els = [int(x.strip("( ")) for x in tup.split(",")]
                seq.append(tuple(els))

            return tuple(seq)

        if isinstance(default, RequiredAttr):
            raise AttributeError("Required attribute {} not found.".format(key))
        return default

    def get_int_list(self, key, default=RequiredAttr()):
        """Get an int list attribute.

        Parameters
        ----------
        key : str
            The attribute key

        default : list of int
            The default value.

        Returns
        -------
        value : tuple of int
            The parsed values if the attribute is present, otherwise the
            default.
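
        Examples
        --------
        Illustrative only:

        >>> attrs = StrAttrsDict({"axes": "[0, 2, 1]"})
        >>> attrs.get_int_list("axes")
        (0, 2, 1)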
        """
        if key in self.attrs:
            tshape = self.attrs[key]
            return tuple(int(x.strip()) for x in tshape.strip("[]()").split(","))
        if isinstance(default, RequiredAttr):
            raise AttributeError("Required attribute {} not found.".format(key))
        return default

    def get_bool(self, key, default=RequiredAttr()):
        """Get a bool attribute.

        Parameters
        ----------
        key : str
            The attribute key

        default : bool
            The default value.

        Returns
        -------
        value : bool
            The parsed value if the attribute is present (truthy strings are
            "true", "1", "t", "y", "yes"; anything else parses as False),
            otherwise the default.
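
        Examples
        --------
        Illustrative only; any string outside the truthy set parses as False:

        >>> attrs = StrAttrsDict({"use_bias": "True", "flatten": "0"})
        >>> attrs.get_bool("use_bias")
        True
        >>> attrs.get_bool("flatten")
        False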
        """
        if key in self.attrs:
            val = self.attrs[key]
            return val.strip().lower() in ["true", "1", "t", "y", "yes"]
        if isinstance(default, RequiredAttr):
            raise AttributeError("Required attribute {} not found.".format(key))
        return default


def get_relay_op(op_name):
    """Get the callable function from Relay based on operator name.

    Parameters
    ----------
    op_name : str
        The Relay operator name.
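
    Examples
    --------
    Illustrative only:

    >>> conv2d = get_relay_op("nn.conv2d")  # explicit hierarchical lookup
    >>> relu = get_relay_op("relu")  # searched across common op modules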
    """
    if "." in op_name:
        # explicit hierarchical modules
        op = _op
        try:
            for opn in op_name.split("."):
                op = getattr(op, opn)
        except AttributeError:
            op = None
    else:
        # try to search for the op in various modules
        for candidate in (_op, _op.nn, _op.image, _op.vision, _op.contrib):
            op = getattr(candidate, op_name, None)
            if op is not None:
                break
    if not op:
        raise tvm.error.OpNotImplemented("Unable to map op_name {} to relay".format(op_name))
    return op


class ExprTable(object):
    """Table storing Relay expressions by names."""

    def __init__(self):
        self.exprs = {}
        self.params = {}
        self.const_ctr = 1
        self.in_padding = False

    def new_const(self, value, shape=None, dtype="float32"):
        name = "_param_%d" % (self.const_ctr)
        if hasattr(value, "shape"):
            shape = value.shape
        self.const_ctr += 1
        self.params[name] = value
        self.exprs[name] = _expr.var(name_hint=name, shape=shape, dtype=dtype)
        return self.exprs[name]

    def get_expr(self, name):
        return self.exprs[name]

    def set_expr(self, name, expr, force_override=False):
        assert isinstance(expr, _expr.Expr)
        # If the name exists, we should override the value;
        # otherwise, patterns like x = func(x) cannot work.
        # One example is the CoreML preprocess, which overrides
        # the same name as its input.
        # However, according to the git log, the Keras frontend depends
        # on the non-overriding behavior, so we add force_override to
        # control it.
        if name not in self.exprs or force_override:
            self.exprs[name] = expr

    def has_expr(self, name):
        return name in self.exprs

    def set_padding(self, paddings):
        self.paddings = paddings
        self.in_padding = True

    def clear_padding(self):
        self.in_padding = False


class AttrCvt(object):
    """Common attribute converter. An AttrCvt instance is a callable:
    ```
    attr_converter = AttrCvt(op_name, transforms={'a':'b', 'c':('d', 1)})
    new_expr = attr_converter(inputs, attrs)
    ```

    Parameters
    ----------
    op_name : str or callable
        If set as str, the returned operator name is the str.
        If set as callable, the returned operator name is the str returned
        by calling: `op_name = func(attr)`

    transforms : dict of `new_name, or (new_name, default_value, transform function)`
        If only a new_name is provided, the attribute is simply renamed.
        If default_value is provided, the attribute is considered optional.
        If a transform function is provided, the original attribute value is
        handled by the transform function.

    excludes : list
        A list of excluded attributes that should `NOT` appear.
        NotImplementedError is raised if one does.

    disables : list
        A list of attributes that are disabled in relay. A warning is logged.

    ignores : list
        A list of attributes that are ignored in relay. A warning is logged
        when one is encountered.

    extras : dict
        A series of additional attributes that should be added unconditionally
        to the returned attribute dict.

    custom_check : tuple of (callable, str)
        A tuple of (function, message). The function takes the attributes and
        returns True/False; RuntimeError with the message is raised when it
        returns False.
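
    Examples
    --------
    A minimal sketch; the frontend attribute names here are made up for
    illustration:

    >>> convert = AttrCvt(
    ...     op_name="conv2d",
    ...     transforms={"kernel_shape": "kernel_size", "pads": ("padding", (0, 0))},
    ...     ignores=["group"],
    ... )
    >>> # new_expr = convert(inputs, attrs)  # returns the converted Relay call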
    """

    def __init__(
        self,
        op_name,
        transforms=None,
        excludes=None,
        disables=None,
        ignores=None,
        extras=None,
        custom_check=None,
    ):
        self._op_name = op_name
        self._transforms = transforms if transforms else {}
        self._excludes = excludes if excludes else []
        self._disables = disables if disables else []
        self._ignores = ignores if ignores else []
        self._extras = extras if extras else {}
        self._custom_check = custom_check

    def __call__(self, inputs, attrs, *args):
        self._ignores.append("_output_shapes")
        self._ignores.append("_input_shapes")
        self._ignores.append("T")
        self._ignores.append("use_cudnn_on_gpu")
        self._ignores.append("_node_name")
        self._ignores.append("is_training")
        self._ignores.append("_target_layout")

        # apply custom check
        if self._custom_check:
            func, msg = self._custom_check
            if not func(attrs):
                raise RuntimeError("Check failed: {}".format(msg))
        # get new op_name
        if isinstance(self._op_name, str):
            op_name = self._op_name
        else:
            assert callable(self._op_name), "op_name can either be string or callable"
            op_name = self._op_name(attrs)

        # always ignore 'tvm_custom'
        self._ignores.append("tvm_custom")

        # convert attributes
        new_attrs = {}
        for k in attrs.keys():
            if k in self._excludes:
                raise NotImplementedError(
                    "Attribute {} in operator {} is not supported.".format(k, op_name)
                )
            if k in self._disables:
                logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name)
            elif k in self._ignores:
                if k != "tvm_custom":
                    logging.warning("Attribute %s is ignored in relay.sym.%s", k, op_name)
            elif k in self._transforms:
                new_name, defaults, transform = self._parse_default(self._transforms[k])
                if defaults is None:
                    new_attr = self._required_attr(attrs, k)
                else:
                    new_attr = attrs.get(k, None)
                if new_attr is None:
                    new_attrs[new_name] = defaults
                else:
                    new_attrs[new_name] = transform(new_attr)
            else:
                # copy
                new_attrs[k] = attrs[k]
        # add extras
        new_attrs.update(self._extras)
        return get_relay_op(op_name)(*inputs, **new_attrs)

    def _parse_default(self, target):
        """Helper function to parse default values."""
        if not isinstance(target, (list, tuple)):
            k, v, t = target, None, lambda x: x
        elif len(target) == 1:
            k, v, t = target[0], None, lambda x: x
        elif len(target) == 2:
            k, v, t = target[0], target[1], lambda x: x
        elif len(target) > 2:
            k, v, t = target[0], target[1], target[2]
        else:
            k = None  # should raise
        if not isinstance(k, str):
            msg = "{} is not a valid target, (name, default) expected.".format(target)
            raise ValueError(msg)
        return k, v, t

    def _parse_bool(self, value):
        """Helper function to parse default boolean values."""
        if isinstance(value, str):
            return value.strip().lower() in ["true", "1", "t", "y", "yes"]
        return bool(value)

    def _required_attr(self, attr, key):
        """Wrapper for getting required attributes."""
        assert isinstance(attr, dict)
        if key not in attr:
            raise AttributeError("Required attribute {} not found.".format(key))
        return attr[key]


def get_name(node):
    name = ""
    if hasattr(node, "name_hint"):
        name = node.name_hint
    return name


def infer_type(node, mod=None):
    """A method to infer the type of an intermediate node in the relay graph."""
    if isinstance(mod, IRModule):
        mod["main"] = _function.Function(tvm.relay.analysis.free_vars(node), node)
        mod = _transform.InferType()(mod)
        entry = mod["main"]
        ret = entry.body
    else:
        new_mod = IRModule.from_expr(node)
        if mod is not None:
            new_mod.update(mod)
        new_mod = _transform.InferType()(new_mod)
        entry = new_mod["main"]
        ret = entry if isinstance(node, _function.Function) else entry.body

    return ret


def infer_channels(inputs, transpose=False):
    """A hack for getting 'channels' or 'units' since Caffe2 does not provide
    these attributes. We check the shape of the weights provided to get the number.
    """
    out_type = infer_type(inputs)
    out_shapes = [get_const_tuple(out_type.checked_type.shape)]
    channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
    return channels


def infer_shape(inputs, mod=None):
    """A method to get the output shape of an intermediate node in the graph."""
    out_type = infer_type(inputs, mod=mod)
    checked_type = out_type.checked_type
    if hasattr(checked_type, "shape"):
        # Regular operator that outputs tensors
        return get_const_tuple(checked_type.shape)
    # The return type is not a tensor, for example List
    return checked_type


def infer_value(input_val, params, mod=None):
    """A hack for getting the value of an expression by evaluating a
    portion of the relay graph. This is often needed for operators whose
    output shape depends on the value of a tensor.
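
    Examples
    --------
    A minimal sketch; the expression below is a constant subgraph, so no
    params are needed:

    >>> a = tvm.relay.const(np.array([1.0, 2.0], "float32"))
    >>> infer_value(a + a, {}).asnumpy()
    array([2., 4.], dtype=float32)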
    """
    # Check that all free variables have associated parameters.
    assert all(
        var.name_hint in params.keys() for var in analysis.free_vars(input_val)
    ), "All inputs to infer must be available in params."
    try:
        # TODO(kevinthesun): Use VM for all cases.
        # pylint: disable=import-outside-toplevel
        from tvm.contrib import graph_runtime

        func = _function.Function(analysis.free_vars(input_val), input_val)
        with tvm.transform.PassContext(opt_level=0):
            lib = tvm.relay.build(func, target="llvm", params=params)
        ctx = tvm.cpu(0)
        m = graph_runtime.GraphModule(lib["default"](ctx))
        m.run()
        return m.get_output(0)
    except Exception:
        if isinstance(mod, IRModule):
            mod["main"] = _function.Function(analysis.free_vars(input_val), input_val)
        else:
            mod = IRModule.from_expr(input_val)
        exc = tvm.relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
        inputs = []
        for param in mod["main"].params:
            inputs.append(params[param.name_hint])
        result = exc.evaluate()(*inputs)
        return result


def infer_value_simulated(input_val, params):
    """Extension to infer_value that can be used when some input
    values are missing. This function creates dummy inputs with the same
    shape and random values, then calls infer_value. This is helpful when
    implementing certain ONNX operators where we need to evaluate the graph
    to determine a static shape.
    """
    fake_params = []
    # Add a fake copy of all missing params.
    for free_param in analysis.free_vars(input_val):
        if free_param.name_hint not in params:
            fp_dtype = free_param.type_annotation.dtype
            fp_shape = [s.value for s in free_param.type_annotation.shape]
            fake_params.append(free_param)
            params[free_param.name_hint] = tvm.nd.array(np.random.rand(*fp_shape).astype(fp_dtype))
    # Now infer the value.
    output_value = infer_value(input_val, params)
    # Clean fake params out of the param dictionary.
    for fake_p in fake_params:
        params.pop(fake_p.name_hint, None)
    return output_value


def try_infer_value(val, on_success=None, on_failure=None):
    """Try running infer_value on the input val, and if successful, return the
    inferred value or pass it to the on_success callback if provided. Otherwise,
    run the on_failure callback if it is provided, or return the input val as
    output. In each case, the second return value indicates whether infer_value
    has succeeded or not.
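
    Examples
    --------
    A minimal sketch; a constant subgraph infers successfully:

    >>> a = tvm.relay.const(np.array([2, 3], "int64"))
    >>> value, success = try_infer_value(a + a)
    >>> success
    True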
    """
    try:
        ret = infer_value(val, {}).asnumpy()
        if on_success:
            return on_success(ret), True
        return ret, True
    except Exception:
        if on_failure:
            return on_failure(), False
        return val, False


def new_var(name_hint, type_annotation=None, shape=None, dtype="float32"):
    return _expr.var(name_hint, type_annotation, shape, dtype)


class Renamer(object):
    """A simple renamer for operators.

    Parameters
    ----------
    new_name : str
        The new name for the operator
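
    Examples
    --------
    Illustrative only:

    >>> # map a frontend op straight to relay's 'tanh', attributes intact
    >>> convert = Renamer("tanh")
    >>> # out = convert(inputs, attrs)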
    """

    def __init__(self, new_name):
        self._new_name = new_name

    def __call__(self, inputs, attrs, *args):
        if "tvm_custom" in attrs:
            attrs.pop("tvm_custom")
        return get_relay_op(self._new_name)(*inputs, **attrs)