# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17"""Common utilities"""
18from __future__ import absolute_import as _abs
19import logging
20
21import tvm
22import numpy as np
23from topi.util import get_const_tuple
24from .. import expr as _expr
25from .. import module as _module
26from .. import transform as _transform
27from .. import op as _op
28from .. import analysis
29
30
class RequiredAttr(object):
    """Dummy class to represent a required attribute."""


class StrAttrsDict(object):
    """Helper class to parse attrs stored as Dict[str, str].

    Parameters
    ----------
    attrs : Dict[str, str]
        The attributes to be used.
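
    Examples
    --------
    A minimal sketch; the attribute names and values are illustrative:

    >>> attrs = StrAttrsDict({"axis": "1", "keepdims": "True"})
    >>> attrs.get_int("axis")
    1
    >>> attrs.get_bool("keepdims")
    True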
    """
    def __init__(self, attrs):
        self.attrs = attrs

    def has_attr(self, key):
        """Checks if an attribute is present in the map.

        Parameters
        ----------
        key : str
            The attribute key

        Returns
        -------
        bool : True if the key is present in the attributes, False otherwise.
        """
        return key in self.attrs

    def get_float(self, key, default=RequiredAttr()):
        """Get float attribute

        Parameters
        ----------
        key : str
            The attribute key

        default : float
            The default value.

        Returns
        -------
        value : The result
        """
        if key in self.attrs:
            return float(self.attrs[key])
        if isinstance(default, RequiredAttr):
            raise AttributeError("Required attribute {} not found.".format(key))
        return default

    def get_int(self, key, default=RequiredAttr()):
        """Get int attribute

        Parameters
        ----------
        key : str
            The attribute key

        default : int
            The default value.

        Returns
        -------
        value : The result
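
        Examples
        --------
        A small sketch; the stored strings are illustrative:

        >>> attrs = StrAttrsDict({"axis": "2", "depth": "None"})
        >>> attrs.get_int("axis")
        2
        >>> attrs.get_int("depth") is None
        True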
        """
        if key in self.attrs:
            val = self.attrs[key]
            if val == "None":
                return None
            return int(val)
        if isinstance(default, RequiredAttr):
            raise AttributeError("Required attribute {} not found.".format(key))
        return default

    def get_str(self, key, default=RequiredAttr()):
        """Get str attribute

        Parameters
        ----------
        key : str
            The attribute key

        default : str
            The default value.

        Returns
        -------
        value : The result
        """
        if key in self.attrs:
            return self.attrs[key]
        if isinstance(default, RequiredAttr):
            raise AttributeError("Required attribute {} not found.".format(key))
        return default

    def get_int_tuple(self, key, default=RequiredAttr()):
        """Get int tuple attribute

        Parameters
        ----------
        key : str
            The attribute key

        default : tuple of int
            The default value.

        Returns
        -------
        value : The result
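
        Examples
        --------
        A small sketch; the stored strings are illustrative, and "None"
        entries parse as None:

        >>> attrs = StrAttrsDict({"strides": "(2, 2)", "shape": "(None, 3)"})
        >>> attrs.get_int_tuple("strides")
        (2, 2)
        >>> attrs.get_int_tuple("shape")
        (None, 3)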
        """
        if key in self.attrs:
            tshape = self.attrs[key]
            return tuple(int(x) if x.strip("- ").isdigit() else None
                         for x in tshape.strip('()[]').split(',') if x)
        if isinstance(default, RequiredAttr):
            raise AttributeError("Required attribute {} not found.".format(key))
        return default

    def get_float_tuple(self, key, default=RequiredAttr()):
        """Get float tuple attribute

        Parameters
        ----------
        key : str
            The attribute key

        default : tuple of float
            The default value.

        Returns
        -------
        value : The result
        """
        if key in self.attrs:
            tshape = self.attrs[key]
            return tuple(float(x.strip()) for x in
                         tshape.strip('()[]').split(','))
        if isinstance(default, RequiredAttr):
            raise AttributeError("Required attribute {} not found.".format(key))
        return default

    def get_tuple_tuple_int(self, key, default=RequiredAttr()):
        """Get tuple of int tuples attribute

        Parameters
        ----------
        key : str
            The attribute key

        default : tuple of tuple of int
            The default value.

        Returns
        -------
        value : The result
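
        Examples
        --------
        A small sketch; the stored string is illustrative:

        >>> attrs = StrAttrsDict({"padding": "((1, 1), (2, 2))"})
        >>> attrs.get_tuple_tuple_int("padding")
        ((1, 1), (2, 2))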
        """
        if key in self.attrs:
            value = self.attrs[key]
            seq = []
            for tup in value.strip('()').split('),'):
                tup = tup.strip('[]()')
                els = [int(x.strip('( ')) for x in tup.split(',')]
                seq.append(tuple(els))

            return tuple(seq)

        if isinstance(default, RequiredAttr):
            raise AttributeError("Required attribute {} not found.".format(key))
        return default

    def get_int_list(self, key, default=RequiredAttr()):
        """Get int list attribute

        Parameters
        ----------
        key : str
            The attribute key

        default : list of int
            The default value.

        Returns
        -------
        value : The result
        """
        if key in self.attrs:
            tshape = self.attrs[key]
            return tuple(int(x.strip()) for x in tshape.strip('[]()').split(','))
        if isinstance(default, RequiredAttr):
            raise AttributeError("Required attribute {} not found.".format(key))
        return default

    def get_bool(self, key, default=RequiredAttr()):
        """Get bool attribute

        Parameters
        ----------
        key : str
            The attribute key

        default : bool
            The default value.

        Returns
        -------
        value : The result
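
        Examples
        --------
        A small sketch; any of 'true', '1', 't', 'y', 'yes'
        (case-insensitive) parses as True:

        >>> attrs = StrAttrsDict({"use_bias": "True", "global_pool": "0"})
        >>> attrs.get_bool("use_bias")
        True
        >>> attrs.get_bool("global_pool")
        False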
        """
        if key in self.attrs:
            val = self.attrs[key]
            return val.strip().lower() in ['true', '1', 't', 'y', 'yes']
        if isinstance(default, RequiredAttr):
            raise AttributeError("Required attribute {} not found.".format(key))
        return default


def get_relay_op(op_name):
    """Get the callable function from Relay based on operator name.

    Parameters
    ----------
    op_name : str
        The Relay operator name.
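
    Examples
    --------
    A minimal sketch; both dotted and bare names resolve to the same op:

    >>> add_fn = get_relay_op("add")
    >>> conv2d_fn = get_relay_op("nn.conv2d")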
    """
    if '.' in op_name:
        # explicit hierarchical modules
        op = _op
        try:
            for opn in op_name.split('.'):
                op = getattr(op, opn)
        except AttributeError:
            op = None
    else:
        # search for the op in the common operator modules
        for candidate in (_op, _op.nn, _op.image, _op.vision, _op.contrib):
            op = getattr(candidate, op_name, None)
            if op is not None:
                break
    if not op:
        raise tvm.error.OpNotImplemented("Unable to map op_name {} to relay".format(op_name))
    return op


class ExprTable(object):
    """Table storing Relay expressions by names.
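
    A minimal usage sketch (the array and names are illustrative):

    >>> table = ExprTable()
    >>> const = table.new_const(np.ones((2, 2), dtype="float32"))
    >>> table.has_expr("_param_1")
    True
    """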
    def __init__(self):
        self.exprs = {}
        self.params = {}
        self.const_ctr = 1
        self.in_padding = False

    def new_const(self, value, shape=None, dtype="float32"):
        name = "_param_%d" % (self.const_ctr)
        if hasattr(value, "shape"):
            shape = value.shape
        self.const_ctr += 1
        self.params[name] = value
        self.exprs[name] = _expr.var(name_hint=name, shape=shape, dtype=dtype)
        return self.exprs[name]

    def get_expr(self, name):
        return self.exprs[name]

    def set_expr(self, name, expr, force_override=False):
        assert isinstance(expr, _expr.Expr)
        # If the name already exists, overriding the value is needed to make
        # patterns like x = func(x) work; one example is the CoreML
        # preprocess step, which overrides the name of its input. However,
        # according to the git log, the Keras frontend depends on keeping
        # the first binding, so force_override controls whether to override.
        if name not in self.exprs or force_override:
            self.exprs[name] = expr

    def has_expr(self, name):
        return name in self.exprs

    def set_padding(self, paddings):
        self.paddings = paddings
        self.in_padding = True

    def clear_padding(self):
        self.in_padding = False


class AttrCvt(object):
    """Common attribute converter. An AttrCvt instance is a callable:
    ```
    attr_converter = AttrCvt(op_name, transforms={'a':'b', 'c':('d', 1)})
    new_op_name, new_attr = attr_converter(attrs)
    ```

    Parameters
    ----------
    op_name : str or callable
        If set as str, the returned operator name is the str.
        If set as callable, the returned operator name is the str returned
        by calling: `op_name = func(attr)`

    transforms : dict of `new_name` or `(new_name, default_value, transform_function)`
        If only a new_name is provided, the attribute is renamed.
        If a default_value is provided, the attribute is considered optional.
        If a transform_function is provided, the original attribute value is
        handled by the transform function.

    excludes : list
        A list of excluded attributes that should `NOT` appear.
        Raises NotImplementedError if one occurs.

    disables : list
        A list of attributes that are disabled in relay. Logs warnings.

    ignores : list
        A list of attributes that are ignored in relay. Logs at debug level.

    extras : dict
        Additional attributes that should be added to the returned
        attribute dict in any case.

    custom_check : tuple of (callable, str)
        A custom function that takes the attributes and returns True/False,
        paired with an error message. Raises RuntimeError with the message
        if the function returns False.
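
    Examples
    --------
    A minimal sketch; the op and attribute names are illustrative:

    >>> converter = AttrCvt(
    ...     op_name='nn.max_pool2d',
    ...     transforms={'kernel': 'pool_size', 'pad': ('padding', (0, 0))},
    ...     ignores=['global_pool'])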
    """
    def __init__(self, op_name, transforms=None,
                 excludes=None, disables=None, ignores=None,
                 extras=None, custom_check=None):
        self._op_name = op_name
        self._transforms = transforms if transforms else {}
        self._excludes = excludes if excludes else []
        self._disables = disables if disables else []
        self._ignores = ignores if ignores else []
        self._extras = extras if extras else {}
        self._custom_check = custom_check

    def __call__(self, inputs, attrs, *args):
        self._ignores.append('_output_shapes')
        self._ignores.append('_input_shapes')
        self._ignores.append('T')
        self._ignores.append('use_cudnn_on_gpu')
        self._ignores.append('_node_name')
        self._ignores.append('is_training')
        self._ignores.append('_target_layout')

        # apply custom check
        if self._custom_check:
            func, msg = self._custom_check
            if not func(attrs):
                raise RuntimeError("Check failed: {}".format(msg))
        # get new op_name
        if isinstance(self._op_name, str):
            op_name = self._op_name
        else:
            assert callable(self._op_name), "op_name can either be string or callable"
            op_name = self._op_name(attrs)

        # ignore 'tvm_custom' always
        self._ignores.append('tvm_custom')

        # convert attributes
        new_attrs = {}
        for k in attrs.keys():
            if k in self._excludes:
                raise NotImplementedError("Attribute {} in operator {} is not"
                                          " supported.".format(k, op_name))
            elif k in self._disables:
                logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name)
            elif k in self._ignores:
                if k != 'tvm_custom':
                    logging.warning("Attribute %s is ignored in relay.sym.%s", k, op_name)
            elif k in self._transforms:
                new_name, defaults, transform = self._parse_default(self._transforms[k])
                if defaults is None:
                    new_attr = self._required_attr(attrs, k)
                else:
                    new_attr = attrs.get(k, None)
                if new_attr is None:
                    new_attrs[new_name] = defaults
                else:
                    new_attrs[new_name] = transform(new_attr)
            else:
                # copy
                new_attrs[k] = attrs[k]
        # add extras
        new_attrs.update(self._extras)
        return get_relay_op(op_name)(*inputs, **new_attrs)

    def _parse_default(self, target):
        """Helper function to parse default values."""
        if not isinstance(target, (list, tuple)):
            k, v, t = target, None, lambda x: x
        elif len(target) == 1:
            k, v, t = target[0], None, lambda x: x
        elif len(target) == 2:
            k, v, t = target[0], target[1], lambda x: x
        elif len(target) > 2:
            k, v, t = target[0], target[1], target[2]
        else:
            k = None  # should raise
        if not isinstance(k, str):
            msg = "{} is not a valid target, (name, default) expected.".format(target)
            raise ValueError(msg)
        return k, v, t

    def _parse_bool(self, value):
        """Helper function to parse default boolean values."""
        if isinstance(value, str):
            return value.strip().lower() in ['true', '1', 't', 'y', 'yes']
        return bool(value)

    def _required_attr(self, attr, key):
        """Wrapper for getting required attributes."""
        assert isinstance(attr, dict)
        if key not in attr:
            raise AttributeError("Required attribute {} not found.".format(key))
        return attr[key]


def get_name(node):
    """Get the name hint of a node, or an empty string if it has none."""
    name = ''
    if hasattr(node, "name_hint"):
        name = node.name_hint
    return name


def infer_type(node, mod=None):
    """Infer the type of an intermediate node in the relay graph."""
    new_mod = _module.Module.from_expr(node)
    if mod is not None:
        new_mod.update(mod)
    new_mod = _transform.InferType()(new_mod)
    entry = new_mod["main"]
    return entry if isinstance(node, _expr.Function) else entry.body


def infer_shape(inputs, mod=None):
    """Get the output shape of an intermediate node in the relay graph.
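
    A minimal sketch (the variable name and shape are illustrative):

    >>> x = new_var("x", shape=(1, 3, 224, 224), dtype="float32")
    >>> infer_shape(x)
    (1, 3, 224, 224)
    """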
    out_type = infer_type(inputs, mod=mod)
    checked_type = out_type.checked_type
    if hasattr(checked_type, 'shape'):
        # Regular operator that outputs tensors
        return get_const_tuple(out_type.checked_type.shape)
    # The return type is not a tensor, for example List
    return checked_type


def infer_channels(inputs, transpose=False):
    """A hack for getting 'channels' or 'units' since Caffe2 does not provide
    these attributes. We check the shape of the weights provided to get the
    number.
    """
    out_type = infer_type(inputs)
    out_shapes = [get_const_tuple(out_type.checked_type.shape)]
    channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
    return channels


def infer_value(input_val, params):
    """A hack for getting the value of an expression by evaluating a
    portion of the relay graph. This is often needed for functions whose
    output shape depends on the value of a tensor.
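
    A minimal sketch (the expression and parameter values are illustrative):

    >>> x = new_var("x", shape=(2,), dtype="float32")
    >>> doubled = _op.multiply(x, _expr.const(2.0))
    >>> result = infer_value(
    ...     doubled, {"x": tvm.nd.array(np.array([1.0, 2.0], dtype="float32"))})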
    """
    from tvm.contrib import graph_runtime
    # Check that all free variables have associated parameters.
    assert all(var.name_hint in params.keys() for var in analysis.free_vars(
        input_val)), "All inputs to infer must be available in params."
    func = _expr.Function(analysis.free_vars(input_val), input_val)
    with tvm.relay.build_config(opt_level=0):
        graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
    ctx = tvm.cpu(0)
    m = graph_runtime.create(graph, lib, ctx)
    m.set_input(**params)
    m.run()
    return m.get_output(0)


def infer_value_simulated(input_val, params):
    """Extension to infer_value that can be used when some input
    values are missing. This function creates dummy inputs with the same
    shape and random values, then calls infer_value. This is helpful when
    implementing certain ONNX operators where we need to evaluate the graph
    to determine a static shape.
    """
    fake_params = []
    # Add a fake copy of all missing params.
    for free_param in analysis.free_vars(input_val):
        if free_param.name_hint not in params:
            fp_dtype = free_param.type_annotation.dtype
            fp_shape = [s.value for s in free_param.type_annotation.shape]
            fake_params.append(free_param)
            params[free_param.name_hint] = tvm.nd.array(
                np.random.rand(*fp_shape).astype(fp_dtype)
            )
    # Now infer the value.
    output_value = infer_value(input_val, params)
    # Clean fake params out of param dictionary.
    for fake_p in fake_params:
        params.pop(fake_p.name_hint, None)
    return output_value


def new_var(name_hint,
            type_annotation=None,
            shape=None,
            dtype="float32"):
    """Create a new Relay variable with the given name hint."""
    return _expr.var(name_hint, type_annotation, shape, dtype)


class Renamer(object):
    """A simple renamer for operators.

    Parameters
    ----------
    new_name : str
        The new name for the operator
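
    Examples
    --------
    A minimal sketch; the op name is illustrative:

    >>> relu_converter = Renamer('nn.relu')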
    """
    def __init__(self, new_name):
        self._new_name = new_name

    def __call__(self, inputs, attrs, *args):
        if 'tvm_custom' in attrs:
            attrs.pop('tvm_custom')
        return get_relay_op(self._new_name)(*inputs, **attrs)