1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""Shared functions and classes for frontends."""
18from __future__ import absolute_import as _abs
19import logging
20from nnvm import sym as _sym
21from .._base import string_types
22
def get_nnvm_op(op_name):
    """Look up an NNVM symbol operator by name.

    Parameters
    ----------
    op_name : str
        Name of the attribute to fetch from the ``nnvm.sym`` module.

    Returns
    -------
    op : object
        The attribute of ``nnvm.sym`` called `op_name`.

    Raises
    ------
    NotImplementedError
        If ``nnvm.sym`` has no attribute called `op_name`, or the
        attribute is falsy.
    """
    # Supplying a default keeps a missing operator from escaping as a bare
    # AttributeError, which made the check below unreachable.
    op = getattr(_sym, op_name, None)
    if not op:
        # Builtin exception: this module defines no dedicated
        # "operator not implemented" error class. Report the requested
        # name, not the (empty) lookup result.
        raise NotImplementedError(
            'Operator {} is not supported.'.format(op_name))
    return op
29
def required_attr(attr, key, op_name):
    """Fetch a mandatory attribute of an operator.

    Parameters
    ----------
    attr : dict
        Attribute dictionary of the operator.
    key : str
        Name of the attribute that must be present.
    op_name : str
        Operator name, used only in the error message.

    Returns
    -------
    value : object
        ``attr[key]``.

    Raises
    ------
    AttributeError
        If `key` is not in `attr`.
    """
    assert isinstance(attr, dict)
    if key not in attr:
        # Builtin AttributeError, the same exception type that
        # AttrConverter._required_attr raises for a missing attribute.
        raise AttributeError(
            'Required attribute {} not found in operator {}'.format(key, op_name))
    return attr[key]
36
def parse_tshape(tshape):
    """Parse a shape given as a string into a list of ints.

    Parameters
    ----------
    tshape : str
        Comma-separated integers, optionally wrapped in parentheses,
        e.g. ``"(1, 3, 224, 224)"``. A single trailing comma (``"(3,)"``)
        and an empty shape (``"()"``) are accepted.

    Returns
    -------
    shape : list of int
        The parsed dimensions; empty when the string holds no dimension.

    Raises
    ------
    ValueError
        If any comma-separated field is not an integer literal.
    """
    body = tshape.strip().strip('()').strip()
    # A one-element tuple prints as "(3,)"; drop that single trailing comma
    # so it does not produce an empty field.
    if body.endswith(','):
        body = body[:-1]
    if not body.strip():
        return []
    # int() tolerates surrounding whitespace, so fields need no extra strip.
    return [int(dim) for dim in body.split(',')]
40
def parse_bool_str(attr, key, default='False'):
    """Interpret the string stored under `key` as a boolean.

    Parameters
    ----------
    attr : dict
        Attribute dictionary holding string values.
    key : str
        Attribute to read.
    default : str
        String used when `key` is absent from `attr`.

    Returns
    -------
    flag : bool
        True when the value, stripped and lower-cased, is one of
        'true', '1', 't', 'y' or 'yes'; False otherwise.
    """
    truthy_spellings = ('true', '1', 't', 'y', 'yes')
    raw_value = attr.get(key, default)
    normalized = raw_value.strip().lower()
    return normalized in truthy_spellings
44
class Renamer(object):
    """Callable that maps an operator onto an NNVM operator of another name.

    Parameters
    ----------
    new_name : str
        Name of the operator that is looked up with `get_nnvm_op` on
        every call.
    """
    def __init__(self, new_name):
        self._new_name = new_name

    def __call__(self, inputs, attrs, *args):
        # Extra positional arguments are accepted but not used.
        target_op = get_nnvm_op(self._new_name)
        return target_op(*inputs, **attrs)
58
59
class AttrConverter(object):
    """Generic attribute converter. An instance is a callable:
    ```
    converter = AttrConverter(op_name, transforms={'a': 'b', 'c': ('d', 1)})
    new_sym = converter(inputs, attrs)
    ```
    Calling it rewrites `attrs` and invokes the operator found through
    `get_nnvm_op` with `inputs` and the rewritten attributes.

    Each key of `attrs` is handled by the first rule that matches, in this
    order: `excludes`, `disables`, `ignores`, `transforms`; a key matching
    none of them is copied unchanged.

    Parameters
    ----------
    op_name : str or callable
        If a str, it is the name of the target operator.
        If a callable, the name is obtained with `op_name(attrs)`.
    transforms : dict of `new_name`, or `(new_name, default_value, transform function)`
        Keyed by the original attribute name.
        A bare `new_name` only renames the attribute.
        `default_value` is stored in place of an attribute whose value is None.
        `transform function` is applied to any attribute value that is not None.
    excludes : list
        Attributes that must not appear.
        NotImplementedError is raised when one of them is present.
    disables : list
        Attributes that are dropped with a warning-level log message.
    ignores : list
        Attributes that are dropped with a debug-level log message.
    extras : dict
        Attributes merged into the result after conversion; they overwrite
        converted attributes of the same name.
    custom_check : tuple of (callable, str)
        The callable receives the attributes; when its result is falsy,
        RuntimeError is raised with the str in the message.
    """
    def __init__(self, op_name, transforms=None,
                 excludes=None, disables=None, ignores=None,
                 extras=None, custom_check=None):
        self._op_name = op_name
        self._transforms = transforms or {}
        self._excludes = excludes or []
        self._disables = disables or []
        self._ignores = ignores or []
        self._extras = extras or {}
        self._custom_check = custom_check

    def __call__(self, inputs, attrs, *args):
        # Run the user supplied validation first, if there is one.
        if self._custom_check:
            check_fn, check_msg = self._custom_check
            if not check_fn(attrs):
                raise RuntimeError("Check failed: {}".format(check_msg))

        # Work out the name of the target operator.
        if isinstance(self._op_name, string_types):
            op_name = self._op_name
        else:
            assert callable(self._op_name), "op_name can either be string or callable"
            op_name = self._op_name(attrs)

        # Rewrite the attributes one key at a time.
        converted = {}
        for key in attrs.keys():
            if key in self._excludes:
                raise NotImplementedError("Attribute {} not supported yet.".format(key))
            if key in self._disables:
                logging.warning("Attribute %s is disabled in nnvm.sym.%s", key, op_name)
                continue
            if key in self._ignores:
                logging.debug("Attribute %s is ignored in nnvm.sym.%s", key, op_name)
                continue
            if key not in self._transforms:
                # No rule for this attribute: pass it through untouched.
                converted[key] = attrs[key]
                continue
            new_name, fallback, transform = self._parse_default(self._transforms[key])
            if fallback is None:
                value = self._required_attr(attrs, key)
            else:
                value = attrs.get(key, None)
            # A None value is replaced by the fallback and never transformed.
            converted[new_name] = fallback if value is None else transform(value)

        # Extras are applied last, so they win over converted attributes.
        converted.update(self._extras)
        return get_nnvm_op(op_name)(*inputs, **converted)

    def _parse_default(self, target):
        """Split a transform spec into (new_name, default, transform)."""
        def _identity(value):
            return value

        if isinstance(target, (list, tuple)):
            count = len(target)
            # An empty sequence leaves the name as None and is rejected below.
            name = target[0] if count > 0 else None
            fallback = target[1] if count > 1 else None
            transform = target[2] if count > 2 else _identity
        else:
            name, fallback, transform = target, None, _identity

        if not isinstance(name, string_types):
            raise ValueError(
                "{} is not a valid target, (name, default) expected.".format(target))
        return name, fallback, transform

    def _parse_bool(self, value):
        """Turn a string or any other value into a boolean."""
        if not isinstance(value, string_types):
            return bool(value)
        return value.strip().lower() in ['true', '1', 't', 'y', 'yes']

    def _required_attr(self, attr, key):
        """Return attr[key], raising AttributeError when the key is missing."""
        assert isinstance(attr, dict)
        if key in attr:
            return attr[key]
        raise AttributeError("Required attribute {} not found.".format(key))
170
171
class SymbolTable(object):
    """Name-indexed store of symbols, constant parameters and padding state."""
    def __init__(self):
        self.vars = {}            # name -> symbol
        self.params = {}          # generated constant name -> constant value
        self.const_ctr = 1        # index used for the next generated name
        self.in_padding = False   # set by set_padding, reset by clear_padding
        self.paddings = [0, 0]    # last value passed to set_padding

    def new_const(self, value):
        """Store `value` under a fresh "_param_<n>" name and return its variable."""
        const_name = "_param_%d" % (self.const_ctr)
        self.const_ctr += 1
        self.params[const_name] = value
        variable = _sym.Variable(name=const_name)
        self.vars[const_name] = variable
        return variable

    def get_var(self, name, must_contain=True):
        """Return the symbol stored under `name`.

        With `must_contain` set the name is asserted to exist; otherwise a
        new variable is created and stored for an unknown name.
        """
        known = name in self.vars
        if must_contain:
            assert known
        if not known:
            self.vars[name] = _sym.Variable(name=name)
        return self.vars[name]

    def set_var(self, name, sym):
        """Store `sym`, which must be an nnvm Symbol, under `name`."""
        assert isinstance(sym, _sym.Symbol)
        self.vars[name] = sym

    def set_padding(self, paddings):
        """Record `paddings` and raise the in-padding flag."""
        self.in_padding = True
        self.paddings = paddings

    def clear_padding(self):
        """Lower the in-padding flag; the recorded paddings are kept."""
        self.in_padding = False
205