1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17# pylint: disable=broad-except 18"""Common utilities""" 19from __future__ import absolute_import as _abs 20import logging 21import numpy as np 22 23import tvm 24from tvm.ir import IRModule 25from tvm.topi.util import get_const_tuple 26 27from .. import expr as _expr 28from .. import function as _function 29from .. import transform as _transform 30from .. import op as _op 31from .. import analysis 32 33 34class RequiredAttr(object): 35 """Dummpy class to represent required attr""" 36 37 38class StrAttrsDict(object): 39 """Helper class to parse attrs stored as Dict[str, str]. 40 41 Parameters 42 ---------- 43 attrs : Dict[str, str] 44 The attributes to be used. 45 """ 46 47 def __init__(self, attrs): 48 self.attrs = attrs 49 50 def has_attr(self, key): 51 """Checks if a attribute is present in the map. 52 53 Parameters 54 ---------- 55 key : str 56 The attribute key 57 58 Returns 59 ------- 60 bool : True if the key is present in the attributes else false. 61 """ 62 return key in self.attrs 63 64 def get_float(self, key, default=RequiredAttr()): 65 """Get float attribute 66 67 Parameters 68 ---------- 69 key : str 70 The attribute key 71 72 default : float 73 The default value. 74 75 Returns 76 ------- 77 value : The result 78 """ 79 if key in self.attrs: 80 return float(self.attrs[key]) 81 if isinstance(default, RequiredAttr): 82 raise AttributeError("Required attribute {} not found.".format(key)) 83 return default 84 85 def get_int(self, key, default=RequiredAttr()): 86 """Get int attribute 87 88 Parameters 89 ---------- 90 key : str 91 The attribute key 92 93 default : float 94 The default value. 95 96 Returns 97 ------- 98 value : The result 99 """ 100 if key in self.attrs: 101 val = self.attrs[key] 102 if val == "None": 103 return None 104 return int(val) 105 if isinstance(default, RequiredAttr): 106 raise AttributeError("Required attribute {} not found.".format(key)) 107 return default 108 109 def get_str(self, key, default=RequiredAttr()): 110 """Get str attribute 111 112 Parameters 113 ---------- 114 key : str 115 The attribute key 116 117 default : float 118 The default value. 119 120 Returns 121 ------- 122 value : The result 123 """ 124 if key in self.attrs: 125 return self.attrs[key] 126 if isinstance(default, RequiredAttr): 127 raise AttributeError("Required attribute {} not found.".format(key)) 128 return default 129 130 def get_int_tuple(self, key, default=RequiredAttr()): 131 """Get int tuple attribute 132 133 Parameters 134 ---------- 135 key : str 136 The attribute key 137 138 default : float 139 The default value. 140 141 Returns 142 ------- 143 value : The result 144 """ 145 if key in self.attrs: 146 tshape = self.attrs[key] 147 return tuple( 148 int(x) if x.strip("- ").isdigit() else None 149 for x in tshape.strip("()[]").split(",") 150 if x 151 ) 152 if isinstance(default, RequiredAttr): 153 raise AttributeError("Required attribute {} not found.".format(key)) 154 return default 155 156 def get_float_tuple(self, key, default=RequiredAttr()): 157 """Get float tuple attribute 158 159 Parameters 160 ---------- 161 key : str 162 The attribute key 163 164 default : float 165 The default value. 166 167 Returns 168 ------- 169 value : The result 170 """ 171 172 if key in self.attrs: 173 tshape = self.attrs[key] 174 return tuple(float(x.strip()) for x in tshape.strip("()[]").split(",")) 175 if isinstance(default, RequiredAttr): 176 raise AttributeError("Required attribute {} not found.".format(key)) 177 return default 178 179 def get_tuple_tuple_int(self, key, default=RequiredAttr()): 180 """Get int list attribute 181 182 Parameters 183 ---------- 184 key : str 185 The attribute key 186 187 default : float 188 The default value. 189 190 Returns 191 ------- 192 value : The result 193 """ 194 if key in self.attrs: 195 value = self.attrs[key] 196 seq = [] 197 for tup in value.strip("()").split("),"): 198 tup = tup.strip("[]()") 199 els = [int(x.strip("( ")) for x in tup.split(",")] 200 seq.append(tuple(els)) 201 202 return tuple(seq) 203 204 if isinstance(default, RequiredAttr): 205 raise AttributeError("Required attribute {} not found.".format(key)) 206 return default 207 208 def get_int_list(self, key, default=RequiredAttr()): 209 """Get int list attribute 210 211 Parameters 212 ---------- 213 key : str 214 The attribute key 215 216 default : float 217 The default value. 218 219 Returns 220 ------- 221 value : The result 222 """ 223 if key in self.attrs: 224 tshape = self.attrs[key] 225 return tuple(int(x.strip()) for x in tshape.strip("[]()").split(",")) 226 if isinstance(default, RequiredAttr): 227 raise AttributeError("Required attribute {} not found.".format(key)) 228 return default 229 230 def get_bool(self, key, default=RequiredAttr()): 231 """Get bool tuple attribute 232 233 Parameters 234 ---------- 235 key : str 236 The attribute key 237 238 default : float 239 The default value. 240 241 Returns 242 ------- 243 value : The result 244 """ 245 if key in self.attrs: 246 val = self.attrs[key] 247 return val.strip().lower() in ["true", "1", "t", "y", "yes"] 248 if isinstance(default, RequiredAttr): 249 raise AttributeError("Required attribute {} not found.".format(key)) 250 return default 251 252 253def get_relay_op(op_name): 254 """Get the callable function from Relay based on operator name. 255 Parameters 256 ---------- 257 op_name : str 258 The Relay operator name. 259 """ 260 if "." in op_name: 261 # explicit hierachical modules 262 op = _op 263 try: 264 for opn in op_name.split("."): 265 op = getattr(op, opn) 266 except AttributeError: 267 op = None 268 else: 269 # try search op in various modules 270 for candidate in (_op, _op.nn, _op.image, _op.vision, _op.contrib): 271 op = getattr(candidate, op_name, None) 272 if op is not None: 273 break 274 if not op: 275 raise tvm.error.OpNotImplemented("Unable to map op_name {} to relay".format(op_name)) 276 return op 277 278 279class ExprTable(object): 280 """Table storing Relay expressions by names.""" 281 282 def __init__(self): 283 self.exprs = {} 284 self.params = {} 285 self.const_ctr = 1 286 self.in_padding = False 287 288 def new_const(self, value, shape=None, dtype="float32"): 289 name = "_param_%d" % (self.const_ctr) 290 if hasattr(value, "shape"): 291 shape = value.shape 292 self.const_ctr += 1 293 self.params[name] = value 294 self.exprs[name] = _expr.var(name_hint=name, shape=shape, dtype=dtype) 295 return self.exprs[name] 296 297 def get_expr(self, name): 298 return self.exprs[name] 299 300 def set_expr(self, name, expr, force_override=False): 301 assert isinstance(expr, _expr.Expr) 302 # if name exists, we should override the value 303 # otherwise, we can not get like x = func(x) work. 304 # One example is CoreML preprocess, which will override 305 # the same name of input. 306 # However, according to git log, Find keras frontend depends 307 # on this property, so we add one force_override to control it. 308 if name not in self.exprs or force_override: 309 self.exprs[name] = expr 310 311 def has_expr(self, name): 312 return name in self.exprs 313 314 def set_padding(self, paddings): 315 self.paddings = paddings 316 self.in_padding = True 317 318 def clear_padding(self): 319 self.in_padding = False 320 321 322class AttrCvt(object): 323 """Common attribute converter. An AttrConverter instance is a callable: 324 ``` 325 attr_converter = AttrConverter(op_name, transforms={'a':'b', 'c':('d', 1)}) 326 new_op_name, new_attr = attr_converter(attrs) 327 ``` 328 329 Parameters 330 ---------- 331 op_name : str or callable 332 If set as str, returned operator name is the str. 333 If set as callable, returned operator is the str returned by calling: 334 `op_name = func(attr)` 335 336 transforms : dict of `new_name, or (new_name, default_value, transform function)` 337 If only a new_name is provided, it's like renaming the attribute name. 338 If default_value if provided, then the attribute is considered as optional. 339 If transform function is provided, the original attribute value is handled 340 by transform function. 341 342 excludes : list 343 A list of excluded attributes that should `NOT` appear. 344 Raise NotImplementedError if occurred. 345 346 disables : list 347 A list of attributes that is disabled in relay. Log warnings. 348 349 ignores : list 350 A list of attributes that is ignored in relay. Debug level logging. 351 352 extras : dict 353 A series of additional attributes should be added anyway to the returned 354 attribute dict. 355 356 custom_check : callable 357 A custom function takes attribute, and return True/False. 358 Raise RuntimeError if not bool(True) returned. 359 """ 360 361 def __init__( 362 self, 363 op_name, 364 transforms=None, 365 excludes=None, 366 disables=None, 367 ignores=None, 368 extras=None, 369 custom_check=None, 370 ): 371 self._op_name = op_name 372 self._transforms = transforms if transforms else {} 373 self._excludes = excludes if excludes else [] 374 self._disables = disables if disables else [] 375 self._ignores = ignores if ignores else [] 376 self._extras = extras if extras else {} 377 self._custom_check = custom_check 378 379 def __call__(self, inputs, attrs, *args): 380 self._ignores.append("_output_shapes") 381 self._ignores.append("_input_shapes") 382 self._ignores.append("T") 383 self._ignores.append("use_cudnn_on_gpu") 384 self._ignores.append("_node_name") 385 self._ignores.append("is_training") 386 self._ignores.append("_target_layout") 387 388 # apply custom check 389 if self._custom_check: 390 func, msg = self._custom_check 391 if not func(attrs): 392 raise RuntimeError("Check failed: {}".format(msg)) 393 # get new op_name 394 if isinstance(self._op_name, str): 395 op_name = self._op_name 396 else: 397 assert callable(self._op_name), "op_name can either be string or callable" 398 op_name = self._op_name(attrs) 399 400 # ignore 'tvm_custom' always 401 self._ignores.append("tvm_custom") 402 403 # convert attributes 404 new_attrs = {} 405 for k in attrs.keys(): 406 if k in self._excludes: 407 raise NotImplementedError( 408 "Attribute %s in operator %s is not" + " supported.", k, op_name 409 ) 410 if k in self._disables: 411 logging.warning("Attribute %s is disabled in relay.sym.%s", k, op_name) 412 elif k in self._ignores: 413 if k != "tvm_custom": 414 logging.warning("Attribute %s is ignored in relay.sym.%s", k, op_name) 415 elif k in self._transforms: 416 new_name, defaults, transform = self._parse_default(self._transforms[k]) 417 if defaults is None: 418 new_attr = self._required_attr(attrs, k) 419 else: 420 new_attr = attrs.get(k, None) 421 if new_attr is None: 422 new_attrs[new_name] = defaults 423 else: 424 new_attrs[new_name] = transform(new_attr) 425 else: 426 # copy 427 new_attrs[k] = attrs[k] 428 # add extras 429 new_attrs.update(self._extras) 430 return get_relay_op(op_name)(*inputs, **new_attrs) 431 432 def _parse_default(self, target): 433 """Helper function to parse default values.""" 434 if not isinstance(target, (list, tuple)): 435 k, v, t = target, None, lambda x: x 436 elif len(target) == 1: 437 k, v, t = target[0], None, lambda x: x 438 elif len(target) == 2: 439 k, v, t = target[0], target[1], lambda x: x 440 elif len(target) > 2: 441 k, v, t = target[0], target[1], target[2] 442 else: 443 k = None # should raise 444 if not isinstance(k, str): 445 msg = "{} is not a valid target, (name, default) expected.".format(target) 446 raise ValueError(msg) 447 return k, v, t 448 449 def _parse_bool(self, value): 450 """Helper function to parse default boolean values.""" 451 if isinstance(value, str): 452 return value.strip().lower() in ["true", "1", "t", "y", "yes"] 453 return bool(value) 454 455 def _required_attr(self, attr, key): 456 """Wrapper for getting required attributes.""" 457 assert isinstance(attr, dict) 458 if key not in attr: 459 raise AttributeError("Required attribute {} not found.".format(key)) 460 return attr[key] 461 462 463def get_name(node): 464 name = "" 465 if hasattr(node, "name_hint"): 466 name = node.name_hint 467 return name 468 469 470def infer_type(node, mod=None): 471 """A method to infer the type of an intermediate node in the relay graph.""" 472 if isinstance(mod, IRModule): 473 mod["main"] = _function.Function(tvm.relay.analysis.free_vars(node), node) 474 mod = _transform.InferType()(mod) 475 entry = mod["main"] 476 ret = entry.body 477 else: 478 new_mod = IRModule.from_expr(node) 479 if mod is not None: 480 new_mod.update(mod) 481 new_mod = _transform.InferType()(new_mod) 482 entry = new_mod["main"] 483 ret = entry if isinstance(node, _function.Function) else entry.body 484 485 return ret 486 487 488def infer_channels(inputs, transpose=False): 489 """A hack for getting 'channels' or 'units' since caffe2 does not provide 490 these attributes. We check the shape of weights provided to get the number. 491 """ 492 out_type = infer_type(inputs) 493 out_shapes = [get_const_tuple(out_type.checked_type.shape)] 494 channels = out_shapes[0][0] if not transpose else out_shapes[0][1] 495 return channels 496 497 498def infer_shape(inputs, mod=None): 499 """A method to get the output type of an intermediate node in the graph.""" 500 out_type = infer_type(inputs, mod=mod) 501 checked_type = out_type.checked_type 502 if hasattr(checked_type, "shape"): 503 # Regular operator that outputs tensors 504 return get_const_tuple(checked_type.shape) 505 # The return type is not a tensor, for example List 506 return checked_type 507 508 509def infer_value(input_val, params, mod=None): 510 """A hack for getting the value of an expression by evaluating a 511 portion of the relay graph. This is often needed for functions that 512 whose output shape depends on the value of a tensor. 513 """ 514 # Check that all free variables have associated parameters. 515 assert all( 516 var.name_hint in params.keys() for var in analysis.free_vars(input_val) 517 ), "All inputs to infer must be available in params." 518 try: 519 # TODO(kevinthesun): Use VM for all cases. 520 # pylint: disable=import-outside-toplevel 521 from tvm.contrib import graph_runtime 522 523 func = _function.Function(analysis.free_vars(input_val), input_val) 524 with tvm.transform.PassContext(opt_level=0): 525 lib = tvm.relay.build(func, target="llvm", params=params) 526 ctx = tvm.cpu(0) 527 m = graph_runtime.GraphModule(lib["default"](ctx)) 528 m.run() 529 return m.get_output(0) 530 except Exception: 531 if isinstance(mod, IRModule): 532 mod["main"] = _function.Function(analysis.free_vars(input_val), input_val) 533 else: 534 mod = IRModule.from_expr(input_val) 535 exc = tvm.relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm") 536 inputs = [] 537 for param in mod["main"].params: 538 inputs.append(params[param.name_hint]) 539 result = exc.evaluate()(*inputs) 540 return result 541 542 543def infer_value_simulated(input_val, params): 544 """Extention to infer_value that can be used when some input 545 values are missing. This function creates dummy inputs with the same 546 shape and random values then calls infer_value. This is helpful when 547 implementing certain onnx operators where we need to evaluate the graph 548 to determine a static shape. 549 """ 550 fake_params = [] 551 # Add a fake copy of all missing params. 552 for free_param in analysis.free_vars(input_val): 553 if free_param.name_hint not in params: 554 fp_dtype = free_param.type_annotation.dtype 555 fp_shape = [s.value for s in free_param.type_annotation.shape] 556 fake_params.append(free_param) 557 params[free_param.name_hint] = tvm.nd.array(np.random.rand(*fp_shape).astype(fp_dtype)) 558 # Now infer the value. 559 output_value = infer_value(input_val, params) 560 # Clean fake params out of param dictionary. 561 for fake_p in fake_params: 562 params.pop(fake_p.name_hint, None) 563 return output_value 564 565 566def try_infer_value(val, on_success=None, on_failure=None): 567 """Try running infer_value on the input val, and if successful, return the inferred value or 568 pass it to on_success callback if provided. Otherwise, run on_failure callback if it is 569 provided, or return the input val as output. In each case, the second return value 570 indicates whether infer_value has succeeded or not. 571 """ 572 try: 573 ret = infer_value(val, {}).asnumpy() 574 if on_success: 575 return on_success(ret), True 576 return ret, True 577 except Exception: 578 if on_failure: 579 return on_failure(), False 580 return val, False 581 582 583def new_var(name_hint, type_annotation=None, shape=None, dtype="float32"): 584 return _expr.var(name_hint, type_annotation, shape, dtype) 585 586 587class Renamer(object): 588 """A simply renamer for operators. 589 590 Parameters 591 ---------- 592 new_name : str 593 The new name for the operator 594 """ 595 596 def __init__(self, new_name): 597 self._new_name = new_name 598 599 def __call__(self, inputs, attrs, *args): 600 if "tvm_custom" in attrs: 601 attrs.pop("tvm_custom") 602 return get_relay_op(self._new_name)(*inputs, **attrs) 603