1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17"""Shared functions and classes for frontends.""" 18from __future__ import absolute_import as _abs 19import logging 20from nnvm import sym as _sym 21from .._base import string_types 22 23def get_nnvm_op(op_name): 24 op = getattr(_sym, op_name) 25 if not op: 26 raise OpNotImplemented( 27 'Operator {} is not supported.'.format(op)) 28 return op 29 30def required_attr(attr, key, op_name): 31 assert isinstance(attr, dict) 32 if key not in attr: 33 raise OpAttributeRequired( 34 'Required attribute {} not found in operator {}'.format(key, op_name)) 35 return attr[key] 36 37def parse_tshape(tshape): 38 """Parse tshape in string.""" 39 return [int(x.strip()) for x in tshape.strip('()').split(',')] 40 41def parse_bool_str(attr, key, default='False'): 42 """Parse bool string to boolean.""" 43 return attr.get(key, default).strip().lower() in ['true', '1', 't', 'y', 'yes'] 44 45class Renamer(object): 46 """A simply renamer for operators. 47 48 Parameters 49 ---------- 50 new_name : str 51 The new name for the operator 52 """ 53 def __init__(self, new_name): 54 self._new_name = new_name 55 56 def __call__(self, inputs, attrs, *args): 57 return get_nnvm_op(self._new_name)(*inputs, **attrs) 58 59 60class AttrConverter(object): 61 """Common attribute converter. An AttrConverter instance is a callable: 62 ``` 63 attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)}) 64 new_op_name, new_attr = attr_converter(attrs) 65 ``` 66 67 Parameters 68 ---------- 69 op_name : str or callable 70 If set as str, returned operator name is the str. 71 If set as callable, returned operator is the str returned by calling: 72 `op_name = func(attr)` 73 transforms : dict of `new_name, or (new_name, default_value, transform function)` 74 If only a new_name is provided, it's like renaming the attribute name. 75 If default_value if provided, then the attribute is considered as optional. 76 If transform function is provided, the original attribute value is handled 77 by transform function. 78 excludes : list 79 A list of excluded attributes that should `NOT` appear. 80 Raise NotImplementedError if occurred. 81 disables : list 82 A list of attributes that is disabled in nnvm. Log warnings. 83 ignores : list 84 A list of attributes that is ignored in nnvm. Debug level logging. 85 extras : dict 86 A series of additional attributes should be added anyway to the returned 87 attribute dict. 88 custom_check : callable 89 A custom function takes attribute, and return True/False. 90 Raise RuntimeError if not bool(True) returned. 91 """ 92 def __init__(self, op_name, transforms=None, 93 excludes=None, disables=None, ignores=None, 94 extras=None, custom_check=None): 95 self._op_name = op_name 96 self._transforms = transforms if transforms else {} 97 self._excludes = excludes if excludes else [] 98 self._disables = disables if disables else [] 99 self._ignores = ignores if ignores else [] 100 self._extras = extras if extras else {} 101 self._custom_check = custom_check 102 103 def __call__(self, inputs, attrs, *args): 104 # apply custom check 105 if self._custom_check: 106 func, msg = self._custom_check 107 if not func(attrs): 108 raise RuntimeError("Check failed: {}".format(msg)) 109 # get new op_name 110 if isinstance(self._op_name, string_types): 111 op_name = self._op_name 112 else: 113 assert callable(self._op_name), "op_name can either be string or callable" 114 op_name = self._op_name(attrs) 115 # convert attributes 116 new_attrs = {} 117 for k in attrs.keys(): 118 if k in self._excludes: 119 raise NotImplementedError("Attribute {} not supported yet.".format(k)) 120 elif k in self._disables: 121 logging.warning("Attribute %s is disabled in nnvm.sym.%s", k, op_name) 122 elif k in self._ignores: 123 logging.debug("Attribute %s is ignored in nnvm.sym.%s", k, op_name) 124 elif k in self._transforms: 125 new_name, defaults, transform = self._parse_default(self._transforms[k]) 126 if defaults is None: 127 new_attr = self._required_attr(attrs, k) 128 else: 129 new_attr = attrs.get(k, None) 130 if new_attr is None: 131 new_attrs[new_name] = defaults 132 else: 133 new_attrs[new_name] = transform(new_attr) 134 else: 135 # copy 136 new_attrs[k] = attrs[k] 137 # add extras 138 new_attrs.update(self._extras) 139 return get_nnvm_op(op_name)(*inputs, **new_attrs) 140 141 def _parse_default(self, target): 142 """Helper function to parse default values.""" 143 if not isinstance(target, (list, tuple)): 144 k, v, t = target, None, lambda x: x 145 elif len(target) == 1: 146 k, v, t = target[0], None, lambda x: x 147 elif len(target) == 2: 148 k, v, t = target[0], target[1], lambda x: x 149 elif len(target) > 2: 150 k, v, t = target[0], target[1], target[2] 151 else: 152 k = None # should raise 153 if not isinstance(k, string_types): 154 msg = "{} is not a valid target, (name, default) expected.".format(target) 155 raise ValueError(msg) 156 return k, v, t 157 158 def _parse_bool(self, value): 159 """Helper function to parse default boolean values.""" 160 if isinstance(value, string_types): 161 return value.strip().lower() in ['true', '1', 't', 'y', 'yes'] 162 return bool(value) 163 164 def _required_attr(self, attr, key): 165 """Wrapper for getting required attributes.""" 166 assert isinstance(attr, dict) 167 if key not in attr: 168 raise AttributeError("Required attribute {} not found.".format(key)) 169 return attr[key] 170 171 172class SymbolTable(object): 173 """Table storing symbols by names.""" 174 def __init__(self): 175 self.vars = {} 176 self.params = {} 177 self.const_ctr = 1 178 self.in_padding = False 179 self.paddings = [0, 0] 180 181 def new_const(self, value): 182 name = "_param_%d" % (self.const_ctr) 183 self.const_ctr += 1 184 self.params[name] = value 185 self.vars[name] = _sym.Variable(name=name) 186 return self.vars[name] 187 188 def get_var(self, name, must_contain=True): 189 if must_contain: 190 assert name in self.vars 191 if name not in self.vars: 192 self.vars[name] = _sym.Variable(name=name) 193 return self.vars[name] 194 195 def set_var(self, name, sym): 196 assert isinstance(sym, _sym.Symbol) 197 self.vars[name] = sym 198 199 def set_padding(self, paddings): 200 self.paddings = paddings 201 self.in_padding = True 202 203 def clear_padding(self): 204 self.in_padding = False 205