1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17"""Common utilities""" 18from __future__ import absolute_import as _abs 19import logging 20 21import tvm 22import numpy as np 23from topi.util import get_const_tuple 24from .. import expr as _expr 25from .. import module as _module 26from .. import transform as _transform 27from .. import op as _op 28from .. import analysis 29 30 31class RequiredAttr(object): 32 """Dummpy class to represent required attr""" 33 34 35class StrAttrsDict(object): 36 """Helper class to parse attrs stored as Dict[str, str]. 37 38 Parameters 39 ---------- 40 attrs : Dict[str, str] 41 The attributes to be used. 42 """ 43 def __init__(self, attrs): 44 self.attrs = attrs 45 46 def has_attr(self, key): 47 """Checks if a attribute is present in the map. 48 49 Parameters 50 ---------- 51 key : str 52 The attribute key 53 54 Returns 55 ------- 56 bool : True if the key is present in the attributes else false. 57 """ 58 return key in self.attrs 59 60 def get_float(self, key, default=RequiredAttr()): 61 """Get float attribute 62 63 Parameters 64 ---------- 65 key : str 66 The attribute key 67 68 default : float 69 The default value. 70 71 Returns 72 ------- 73 value : The result 74 """ 75 if key in self.attrs: 76 return float(self.attrs[key]) 77 if isinstance(default, RequiredAttr): 78 raise AttributeError("Required attribute {} not found.".format(key)) 79 return default 80 81 def get_int(self, key, default=RequiredAttr()): 82 """Get int attribute 83 84 Parameters 85 ---------- 86 key : str 87 The attribute key 88 89 default : float 90 The default value. 91 92 Returns 93 ------- 94 value : The result 95 """ 96 if key in self.attrs: 97 val = self.attrs[key] 98 if val == "None": 99 return None 100 return int(val) 101 if isinstance(default, RequiredAttr): 102 raise AttributeError("Required attribute {} not found.".format(key)) 103 return default 104 105 def get_str(self, key, default=RequiredAttr()): 106 """Get str attribute 107 108 Parameters 109 ---------- 110 key : str 111 The attribute key 112 113 default : float 114 The default value. 115 116 Returns 117 ------- 118 value : The result 119 """ 120 if key in self.attrs: 121 return self.attrs[key] 122 if isinstance(default, RequiredAttr): 123 raise AttributeError("Required attribute {} not found.".format(key)) 124 return default 125 126 def get_int_tuple(self, key, default=RequiredAttr()): 127 """Get int tuple attribute 128 129 Parameters 130 ---------- 131 key : str 132 The attribute key 133 134 default : float 135 The default value. 136 137 Returns 138 ------- 139 value : The result 140 """ 141 if key in self.attrs: 142 tshape = self.attrs[key] 143 return tuple(int(x) if x.strip("- ").isdigit() else None 144 for x in tshape.strip('()[]').split(',') if x) 145 if isinstance(default, RequiredAttr): 146 raise AttributeError("Required attribute {} not found.".format(key)) 147 return default 148 149 def get_float_tuple(self, key, default=RequiredAttr()): 150 """Get float tuple attribute 151 152 Parameters 153 ---------- 154 key : str 155 The attribute key 156 157 default : float 158 The default value. 159 160 Returns 161 ------- 162 value : The result 163 """ 164 165 if key in self.attrs: 166 tshape = self.attrs[key] 167 return tuple(float(x.strip()) for x in 168 tshape.strip('()[]').split(',')) 169 if isinstance(default, RequiredAttr): 170 raise AttributeError("Required attribute {} not found.".format(key)) 171 return default 172 173 def get_tuple_tuple_int(self, key, default=RequiredAttr()): 174 """Get int list attribute 175 176 Parameters 177 ---------- 178 key : str 179 The attribute key 180 181 default : float 182 The default value. 183 184 Returns 185 ------- 186 value : The result 187 """ 188 if key in self.attrs: 189 value = self.attrs[key] 190 seq = [] 191 for tup in value.strip('()').split('),'): 192 tup = tup.strip('[]()') 193 els = [int(x.strip('( ')) for x in tup.split(',')] 194 seq.append(tuple(els)) 195 196 return tuple(seq) 197 198 if isinstance(default, RequiredAttr): 199 raise AttributeError("Required attribute {} not found.".format(key)) 200 return default 201 202 def get_int_list(self, key, default=RequiredAttr()): 203 """Get int list attribute 204 205 Parameters 206 ---------- 207 key : str 208 The attribute key 209 210 default : float 211 The default value. 212 213 Returns 214 ------- 215 value : The result 216 """ 217 if key in self.attrs: 218 tshape = self.attrs[key] 219 return tuple(int(x.strip()) for x in tshape.strip('[]()').split(',')) 220 if isinstance(default, RequiredAttr): 221 raise AttributeError("Required attribute {} not found.".format(key)) 222 return default 223 224 def get_bool(self, key, default=RequiredAttr()): 225 """Get bool tuple attribute 226 227 Parameters 228 ---------- 229 key : str 230 The attribute key 231 232 default : float 233 The default value. 234 235 Returns 236 ------- 237 value : The result 238 """ 239 if key in self.attrs: 240 val = self.attrs[key] 241 return val.strip().lower() in ['true', '1', 't', 'y', 'yes'] 242 if isinstance(default, RequiredAttr): 243 raise AttributeError("Required attribute {} not found.".format(key)) 244 return default 245 246 247def get_relay_op(op_name): 248 """Get the callable function from Relay based on operator name. 249 Parameters 250 ---------- 251 op_name : str 252 The Relay operator name. 253 """ 254 if '.' in op_name: 255 # explicit hierachical modules 256 op = _op 257 try: 258 for opn in op_name.split('.'): 259 op = getattr(op, opn) 260 except AttributeError: 261 op = None 262 else: 263 # try search op in various modules 264 for candidate in (_op, _op.nn, _op.image, _op.vision, _op.contrib): 265 op = getattr(candidate, op_name, None) 266 if op is not None: 267 break 268 if not op: 269 raise tvm.error.OpNotImplemented("Unable to map op_name {} to relay".format(op_name)) 270 return op 271 272 273class ExprTable(object): 274 """Table storing Relay expressions by names.""" 275 def __init__(self): 276 self.exprs = {} 277 self.params = {} 278 self.const_ctr = 1 279 self.in_padding = False 280 281 def new_const(self, value, shape=None, dtype="float32"): 282 name = "_param_%d" % (self.const_ctr) 283 if hasattr(value, "shape"): 284 shape = value.shape 285 self.const_ctr += 1 286 self.params[name] = value 287 self.exprs[name] = _expr.var(name_hint=name, shape=shape, dtype=dtype) 288 return self.exprs[name] 289 290 def get_expr(self, name): 291 return self.exprs[name] 292 293 def set_expr(self, name, expr, force_override=False): 294 assert isinstance(expr, _expr.Expr) 295 # if name exists, we should override the value 296 # otherwise, we can not get like x = func(x) work. 297 # One example is CoreML preprocess, which will override 298 # the same name of input. 299 # However, according to git log, Find keras frontend depends 300 # on this property, so we add one force_override to control it. 301 if name not in self.exprs or force_override: 302 self.exprs[name] = expr 303 304 def has_expr(self, name): 305 return True if name in self.exprs else False 306 307 def set_padding(self, paddings): 308 self.paddings = paddings 309 self.in_padding = True 310 311 def clear_padding(self): 312 self.in_padding = False 313 314 315class AttrCvt(object): 316 """Common attribute converter. An AttrConverter instance is a callable: 317 ``` 318 attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)}) 319 new_op_name, new_attr = attr_converter(attrs) 320 ``` 321 322 Parameters 323 ---------- 324 op_name : str or callable 325 If set as str, returned operator name is the str. 326 If set as callable, returned operator is the str returned by calling: 327 `op_name = func(attr)` 328 329 transforms : dict of `new_name, or (new_name, default_value, transform function)` 330 If only a new_name is provided, it's like renaming the attribute name. 331 If default_value if provided, then the attribute is considered as optional. 332 If transform function is provided, the original attribute value is handled 333 by transform function. 334 335 excludes : list 336 A list of excluded attributes that should `NOT` appear. 337 Raise NotImplementedError if occurred. 338 339 disables : list 340 A list of attributes that is disabled in relay. Log warnings. 341 342 ignores : list 343 A list of attributes that is ignored in relay. Debug level logging. 344 345 extras : dict 346 A series of additional attributes should be added anyway to the returned 347 attribute dict. 348 349 custom_check : callable 350 A custom function takes attribute, and return True/False. 351 Raise RuntimeError if not bool(True) returned. 352 """ 353 def __init__(self, op_name, transforms=None, 354 excludes=None, disables=None, ignores=None, 355 extras=None, custom_check=None): 356 self._op_name = op_name 357 self._transforms = transforms if transforms else {} 358 self._excludes = excludes if excludes else [] 359 self._disables = disables if disables else [] 360 self._ignores = ignores if ignores else [] 361 self._extras = extras if extras else {} 362 self._custom_check = custom_check 363 364 def __call__(self, inputs, attrs, *args): 365 self._ignores.append('_output_shapes') 366 self._ignores.append('_input_shapes') 367 self._ignores.append('T') 368 self._ignores.append('use_cudnn_on_gpu') 369 self._ignores.append('_node_name') 370 self._ignores.append('is_training') 371 self._ignores.append('_target_layout') 372 373 # apply custom check 374 if self._custom_check: 375 func, msg = self._custom_check 376 if not func(attrs): 377 raise RuntimeError("Check failed: {}".format(msg)) 378 # get new op_name 379 if isinstance(self._op_name, str): 380 op_name = self._op_name 381 else: 382 assert callable(self._op_name), "op_name can either be string or callable" 383 op_name = self._op_name(attrs) 384 385 # ignore 'tvm_custom' always 386 self._ignores.append('tvm_custom') 387 388 # convert attributes 389 new_attrs = {} 390 for k in attrs.keys(): 391 if k in self._excludes: 392 raise NotImplementedError('Attribute %s in operator %s is not' + 393 ' supported.', k, op_name) 394 elif k in self._disables: 395 logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name) 396 elif k in self._ignores: 397 if k != 'tvm_custom': 398 logging.warning("Attribute %s is ignored in relay.sym.%s", k, op_name) 399 elif k in self._transforms: 400 new_name, defaults, transform = self._parse_default(self._transforms[k]) 401 if defaults is None: 402 new_attr = self._required_attr(attrs, k) 403 else: 404 new_attr = attrs.get(k, None) 405 if new_attr is None: 406 new_attrs[new_name] = defaults 407 else: 408 new_attrs[new_name] = transform(new_attr) 409 else: 410 # copy 411 new_attrs[k] = attrs[k] 412 # add extras 413 new_attrs.update(self._extras) 414 return get_relay_op(op_name)(*inputs, **new_attrs) 415 416 def _parse_default(self, target): 417 """Helper function to parse default values.""" 418 if not isinstance(target, (list, tuple)): 419 k, v, t = target, None, lambda x: x 420 elif len(target) == 1: 421 k, v, t = target[0], None, lambda x: x 422 elif len(target) == 2: 423 k, v, t = target[0], target[1], lambda x: x 424 elif len(target) > 2: 425 k, v, t = target[0], target[1], target[2] 426 else: 427 k = None # should raise 428 if not isinstance(k, str): 429 msg = "{} is not a valid target, (name, default) expected.".format(target) 430 raise ValueError(msg) 431 return k, v, t 432 433 def _parse_bool(self, value): 434 """Helper function to parse default boolean values.""" 435 if isinstance(value, str): 436 return value.strip().lower() in ['true', '1', 't', 'y', 'yes'] 437 return bool(value) 438 439 def _required_attr(self, attr, key): 440 """Wrapper for getting required attributes.""" 441 assert isinstance(attr, dict) 442 if key not in attr: 443 raise AttributeError("Required attribute {} not found.".format(key)) 444 return attr[key] 445 446 447def get_name(node): 448 name = '' 449 if hasattr(node, "name_hint"): 450 name = node.name_hint 451 return name 452 453 454def infer_type(node, mod=None): 455 """A method to infer the type of an intermediate node in the relay graph.""" 456 new_mod = _module.Module.from_expr(node) 457 if mod is not None: 458 new_mod.update(mod) 459 new_mod = _transform.InferType()(new_mod) 460 entry = new_mod["main"] 461 return entry if isinstance(node, _expr.Function) else entry.body 462 463def infer_shape(inputs, mod=None): 464 """A method to get the output type of an intermediate node in the graph.""" 465 out_type = infer_type(inputs, mod=mod) 466 checked_type = out_type.checked_type 467 if hasattr(checked_type, 'shape'): 468 # Regular operator that outputs tensors 469 return get_const_tuple(out_type.checked_type.shape) 470 # The return type is not a tensor, for example List 471 return checked_type 472 473def infer_channels(inputs, transpose=False): 474 """A hack for getting 'channels' or 'units' since caffe2 does not provide 475 these attributes. We check the shape of weights provided to get the number. 476 """ 477 out_type = infer_type(inputs) 478 out_shapes = [get_const_tuple(out_type.checked_type.shape)] 479 channels = out_shapes[0][0] if not transpose else out_shapes[0][1] 480 return channels 481 482 483def infer_value(input_val, params): 484 """A hack for getting the value of an expression by evaluating a 485 portion of the relay graph. This is often needed for functions that 486 whose output shape depends on the value of a tensor. 487 """ 488 from tvm.contrib import graph_runtime 489 # Check that all free variables have associated parameters. 490 assert all(var.name_hint in params.keys() for var in analysis.free_vars( 491 input_val)), "All inputs to infer must be available in params." 492 func = _expr.Function(analysis.free_vars(input_val), input_val) 493 with tvm.relay.build_config(opt_level=0): 494 graph, lib, params = tvm.relay.build(func, target="llvm", params=params) 495 ctx = tvm.cpu(0) 496 m = graph_runtime.create(graph, lib, ctx) 497 m.set_input(**params) 498 m.run() 499 return m.get_output(0) 500 501 502def infer_value_simulated(input_val, params): 503 """Extention to infer_value that can be used when some input 504 values are missing. This function creates dummy inputs with the same 505 shape and random values then calls infer_value. This is helpful when 506 implementing certain onnx operators where we need to evaluate the graph 507 to determine a static shape. 508 """ 509 fake_params = [] 510 # Add a fake copy of all missing params. 511 for free_param in analysis.free_vars(input_val): 512 if free_param.name_hint not in params: 513 fp_dtype = free_param.type_annotation.dtype 514 fp_shape = [s.value for s in free_param.type_annotation.shape] 515 fake_params.append(free_param) 516 params[free_param.name_hint] = tvm.nd.array( 517 np.random.rand(*fp_shape).astype(fp_dtype) 518 ) 519 # Now infer the value. 520 output_value = infer_value(input_val, params) 521 # Clean fake params out of param dictionary. 522 for fake_p in fake_params: 523 params.pop(fake_p.name_hint, None) 524 return output_value 525 526 527def new_var(name_hint, 528 type_annotation=None, 529 shape=None, 530 dtype="float32"): 531 return _expr.var(name_hint, type_annotation, shape, dtype) 532 533 534class Renamer(object): 535 """A simply renamer for operators. 536 537 Parameters 538 ---------- 539 new_name : str 540 The new name for the operator 541 """ 542 def __init__(self, new_name): 543 self._new_name = new_name 544 545 def __call__(self, inputs, attrs, *args): 546 if 'tvm_custom' in attrs: 547 attrs.pop('tvm_custom') 548 return get_relay_op(self._new_name)(*inputs, **attrs) 549