# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-variable,not-callable
"""Definition of task function.

Task can be constructed from tuple of func, args, and kwargs.
func is a state-less function, or a string that
registers the standard task.
"""
import collections
import functools

import numpy as np

from tvm.target import Target
from tvm import runtime
from tvm.ir import container
from tvm.tir import expr
from tvm.te import tensor, placeholder


from ..util import get_const_int, get_const_tuple
from .dispatcher import DispatchContext, ApplyConfig
from .space import ConfigSpace


def _raise_error(*args, **kwargs):  # pylint: disable=unused-argument
    """Fallback task function used when a task name is missing from TASK_TABLE.

    Raises
    ------
    RuntimeError
        Always.
    """
    raise RuntimeError(
        "The function of this task is not found. Possibly the function "
        "of this task is registered in another python file "
        "which is not imported in this run"
    )


def serialize_args(args):
    """Serialize arguments of a topi function to a hashable tuple.

    Encoding rules applied to each argument (recursively for sequences):

    * ``te.tensor.Tensor`` -> ``("TENSOR", shape_tuple, dtype)``
    * ``tuple`` / ``list`` / ``tvm.ir.container.Array`` -> ``tuple`` of encoded items
    * ``str`` / ``int`` / ``float`` / ``tir.Var`` / ``tir.Any`` -> unchanged
    * NumPy integer / floating scalars -> Python ``int`` / ``float``
    * ``tir.StringImm`` / ``tir.IntImm`` / ``tir.FloatImm`` -> their ``.value``
    * ``runtime.container.String`` -> ``str``
    * ``None`` -> ``None``

    Parameters
    ----------
    args: list of hashable or Tensor
        The arguments to serialize.

    Returns
    -------
    ret: tuple
        The serialized, hashable arguments.

    Raises
    ------
    RuntimeError
        If an argument (or a nested element) has an unsupported type.
    """

    def _encode(x):
        if isinstance(x, tensor.Tensor):
            return ("TENSOR", get_const_tuple(x.shape), x.dtype)
        if isinstance(x, (tuple, list, container.Array)):
            return tuple(_encode(a) for a in x)
        # ``np.int`` / ``np.float`` were plain aliases of the builtins and were
        # removed in NumPy 1.24, so only the builtins themselves are listed.
        if isinstance(x, (str, int, float, expr.Var, expr.Any)):
            return x
        # NumPy scalars that do not subclass a builtin (e.g. np.int64,
        # np.float32) are normalized so the workload stays plain Python data.
        if isinstance(x, np.integer):
            return int(x)
        if isinstance(x, np.floating):
            return float(x)
        if isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
            return x.value
        if isinstance(x, runtime.container.String):
            return str(x)
        if x is None:
            return None
        raise RuntimeError(
            'Do not support type "%s" in argument. Consider to use '
            "primitive types or tvm.tir.Var only" % type(x)
        )

    return tuple(_encode(t) for t in args)


def deserialize_args(args):
    """The inverse function of :code:`serialize_args`.

    Top-level ``("TENSOR", shape, dtype)`` entries are turned back into
    ``te.placeholder`` tensors; every other entry is passed through unchanged.

    Parameters
    ----------
    args: list of hashable or Tensor
        Serialized arguments, as produced by :code:`serialize_args`.

    Returns
    -------
    ret: list
        The deserialized arguments.
    """
    ret = []
    for t in args:
        # A serialized empty list/tuple argument is ``()``; check for emptiness
        # before peeking at the tag so it is passed through instead of crashing.
        if isinstance(t, tuple) and t and t[0] == "TENSOR":
            ret.append(placeholder(shape=t[1], dtype=t[2]))
        else:
            ret.append(t)
    return ret


def args_to_workload(args, task_name=None):
    """Convert argument list to hashable workload tuple.
    This function will convert list to tuple, tvm node to python value and
    flatten te.tensor.Tensor to a tuple

    Parameters
    ----------
    args : list of args
        The arguments to the function

    task_name : str, optional
        The AutoTVM task name. When given, it is prepended to the workload.

    Returns
    -------
    ret: hashable
        The hashable value
    """
    serialized = serialize_args(args)
    if task_name is None:
        return serialized
    return (task_name,) + serialized


class Task(object):
    """A Tunable Task

    Parameters
    ----------
    name: str
        The name of the task.
    args: Tuple
        Positional argument of func
    """

    def __init__(self, name, args):
        self.name = name
        self.args = args
        self.kwargs = {}  # currently unused

        # init null config space
        self.config_space = None
        # Unknown task names resolve to a function that raises on call.
        self.func = TASK_TABLE.get(name, _raise_error)

        # auxiliary info, available after `init_space` is called
        self.flop = None
        self.target = None
        self.target_host = None

    @property
    def workload(self):
        """The hashable workload: ``(name,) + serialize_args(args)``."""
        return (self.name,) + serialize_args(self.args)

    def instantiate(self, config):
        """Instantiate this task function (template) with a config.
        Returns corresponding schedule.

        Parameters
        ----------
        config: template.ConfigEntity
            parameter config for this template

        Returns
        -------
        sch: tvm.te.schedule.Schedule
            The tvm schedule
        arg_bufs: Array of te.tensor.Tensor
            The input/output buffers
        """
        config.flop = 0
        with ApplyConfig(config):
            sch, arg_bufs = self.func(*self.args, **self.kwargs)
        if not self.flop:
            # Prefer a FLOP count the template set via cfg.add_flop; otherwise
            # estimate it from the schedule, and cache it on the task.
            config.flop = config.flop or compute_flop(sch)
            self.flop = config.flop
        return sch, arg_bufs

    def __getstate__(self):
        # custom pickle implementation is required for
        # some unpickable local task functions.
        # So we only pickle the name of the function
        # and restore the function by name when unpickling it.
        return {
            "name": self.name,
            "args": self.args,
            "kwargs": self.kwargs,
            "config_space": self.config_space,
            "flop": self.flop,
            "target": self.target,
            "target_host": self.target_host,
        }

    def __setstate__(self, state):
        self.name = state["name"]
        self.args = state["args"]
        self.kwargs = state["kwargs"]
        self.config_space = state["config_space"]
        # Re-resolve the function by name; see __getstate__.
        self.func = TASK_TABLE.get(state["name"], _raise_error)
        self.flop = state["flop"]
        self.target = state["target"]
        self.target_host = state["target_host"]

    def __repr__(self):
        return "Task(func_name=%s, args=%s, kwargs=%s, workload=%s)" % (
            self.name,
            self.args,
            self.kwargs,
            self.workload,
        )


# Global registry: task name -> TaskTemplate.
TASK_TABLE = {}


class TaskTemplate(object):
    """
    Task template is used to create a tunable AutoTVM task.

    It can be defined by a pair of compute and schedule function using
    `_register_task_compute` and `_register_task_schedule`,
    or by a customized task creation function that is more flexible using
    `_register_customized_task`.

    Note that when customized func is registered, compute and schedule function
    will be ignored
    """

    def __init__(self):
        self.fcompute = None
        self.fschedule = None
        self.fcustomized = None

    def __call__(self, *args, **kwargs):
        """Deserialize ``args`` and build ``(schedule, arg_bufs)``.

        Uses the customized function when one is registered, otherwise the
        registered compute + schedule pair.
        """
        args = deserialize_args(args)
        if self.fcustomized is None:
            return self._default_func(*args, **kwargs)
        assert callable(self.fcustomized)
        return self.fcustomized(*args, **kwargs)

    def _default_func(self, *args, **kwargs):
        """Run fcompute, collect its placeholder inputs, then run fschedule.

        Returns ``(schedule, [out] + placeholder_inputs)``.
        """
        assert callable(self.fcompute) and callable(self.fschedule)
        out = self.fcompute(*args, **kwargs)
        arg_bufs = [out] + self._get_inputs(out)
        s = self.fschedule([out])
        return s, arg_bufs

    @staticmethod
    def _get_inputs(out):
        """Return the placeholder tensors reachable from ``out``, in BFS order."""
        inputs = []
        queue = collections.deque([out])
        visited = set()
        while queue:
            cur = queue.popleft()
            if isinstance(cur.op, tensor.PlaceholderOp):
                inputs.append(cur)
            else:
                # Only enqueue tensors not seen before so shared producers are
                # visited (and shared placeholders reported) once.
                fresh = [inp for inp in cur.op.input_tensors if inp not in visited]
                queue.extend(fresh)
                visited.update(fresh)
        return inputs


def _get_or_create_template(name):
    """Return the TaskTemplate registered as ``name``, creating an empty one if absent."""
    if name not in TASK_TABLE:
        TASK_TABLE[name] = TaskTemplate()
    return TASK_TABLE[name]


def _register_task_compute(name, func=None):
    """Register compute function to autotvm task

    Parameters
    ----------
    name: str
        The task name

    func: None or callable
        If it is None, return a decorator.
        If is callable, decorate this function.

    Returns
    -------
    decorator: callable
        A decorator

    Raises
    ------
    ValueError
        If a compute function is already registered for ``name``.
    """

    def _do_reg(f):
        tmpl = _get_or_create_template(name)
        if tmpl.fcompute is not None:
            raise ValueError("Compute is already registered in autoTVM task %s" % name)
        tmpl.fcompute = f
        return f

    if func:
        return _do_reg(func)
    return _do_reg


def _register_task_schedule(name, func=None):
    """Register schedule function to autotvm task

    Parameters
    ----------
    name: str
        The task name

    func: None or callable
        If it is None, return a decorator.
        If is callable, decorate this function.

    Returns
    -------
    decorator: callable
        A decorator

    Raises
    ------
    ValueError
        If a schedule function is already registered for ``name``.
    """

    def _do_reg(f):
        tmpl = _get_or_create_template(name)
        if tmpl.fschedule is not None:
            raise ValueError("Schedule is already registered in autoTVM task %s" % name)
        tmpl.fschedule = f
        return f

    if func:
        return _do_reg(func)
    return _do_reg


def _register_customized_task(name, func=None):
    """Register a customized function to AutoTVM task.

    Parameters
    ----------
    name: str
        The task name

    func: None or callable
        If it is None, return a decorator.
        If is callable, decorate this function.

    Returns
    -------
    decorator: callable
        A decorator

    Raises
    ------
    ValueError
        If a customized function is already registered for ``name``.
    """

    def _do_reg(f):
        tmpl = _get_or_create_template(name)
        if tmpl.fcustomized is not None:
            raise ValueError("Customized func is already registered in autoTVM task %s" % name)
        tmpl.fcustomized = f
        return f

    if func:
        return _do_reg(func)
    return _do_reg


def template(task_name, func=None):
    """Decorate a function as a tunable schedule template.

    Parameters
    ----------
    task_name: str
        The task name

    func: None or callable
        A callable template function.
        If it is None, return a decorator.
        If is callable, decorate this function.

    Returns
    -------
    func: callable
        The decorated function

    Examples
    --------
    The following code is a tunable template for a blocked matrix multiplication

    .. code-block:: python

        @autotvm.template("matmul")
        def matmul(N, L, M, dtype):
            A = te.placeholder((N, L), name='A', dtype=dtype)
            B = te.placeholder((L, M), name='B', dtype=dtype)

            k = te.reduce_axis((0, L), name='k')
            C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name='C')
            s = te.create_schedule(C.op)

            # schedule
            y, x = s[C].op.axis
            k = s[C].op.reduce_axis[0]

            ##### define space begin #####
            cfg = autotvm.get_config()
            cfg.define_split("tile_y", y, num_outputs=2)
            cfg.define_split("tile_x", x, num_outputs=2)
            ##### define space end #####

            # schedule according to config
            yo, yi = cfg["tile_y"].apply(s, C, y)
            xo, xi = cfg["tile_x"].apply(s, C, x)

            s[C].reorder(yo, xo, k, yi, xi)

            return s, [A, B, C]
    """

    def _decorate(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            assert not kwargs, "Do not support kwargs in template function call"
            # Look up the config for this exact workload on the current target
            # and run the template under it.
            workload = args_to_workload(args, task_name)
            tgt = Target.current()
            cfg = DispatchContext.current.query(tgt, workload)
            with ApplyConfig(cfg):
                return f(*args, **kwargs)

        _register_customized_task(task_name, f)
        return wrapper

    if func:
        return _decorate(func)
    return _decorate


def create(task_name, args, target, target_host=None):
    """Create a tuning task and initialize its search space

    Parameters
    ----------
    task_name : str
        The AutoTVM task name
    args : List
        Positional arguments
    target : Target or str
        The compilation target; a string is converted with ``Target(target)``.
    target_host: Target, optional
        The compilation target for host side (stored as given)

    Returns
    -------
    tsk: Task
        a task object
    """
    args = serialize_args(args)
    ret = Task(task_name, args)

    if isinstance(target, str):
        target = Target(target)

    # init config space
    ret.config_space = ConfigSpace()

    # Running the template under the empty ConfigSpace records every knob it
    # defines, which populates the search space.
    ctx = ApplyConfig(ret.config_space)
    with ctx:
        with target:
            sch, _ = ret.func(*args)
            ret.config_space.code_hash = getattr(sch, "code_hash", None)

    ret.flop = ret.config_space.flop or compute_flop(sch)
    ret.target = target
    ret.target_host = target_host

    return ret


def get_config():
    """Get current config object

    Returns
    -------
    cfg: ConfigSpace or ConfigEntity
        The current config
    """
    tgt = Target.current(allow_none=True)
    return DispatchContext.current.query(tgt, None)


class FlopCalculationError(RuntimeError):
    """Error happens when estimating FLOP for a compute op"""


def compute_flop(sch):
    """Calculate number of FLOP (floating number operations) of the compute ops in a schedule

    Parameters
    ----------
    sch: tvm.te.schedule.Schedule
        schedule

    Returns
    -------
    flop: int
        number of FLOP in this schedule

    Raises
    ------
    RuntimeError
        If the estimator hits an unsupported construct, or the count is zero.
    """

    def _prod_length(axes):
        """compute product of the lengths of a list of axes"""
        try:
            num_iter = int(np.prod([get_const_int(axis.dom.extent) for axis in axes]))
        except ValueError as exc:
            raise FlopCalculationError("The length of axis is not constant. ") from exc
        return num_iter

    def _count_flop(exp):
        """compute flop for a single expression"""
        if isinstance(exp, expr.Reduce):
            num_iter = _prod_length(exp.axis)
            combiner = exp.combiner.result
            source = exp.source
            if len(combiner) != 1:
                raise FlopCalculationError("Found multiple output in the combiner of reduce op")
            if len(source) != 1:
                raise FlopCalculationError("Found multiple output in the source of reduce op")
            return num_iter * (_count_flop(combiner[0]) + _count_flop(source[0]))
        if isinstance(exp, (expr.FloatImm, expr.IntImm)):
            return 0
        if isinstance(exp, expr.Cast):
            return _count_flop(exp.value)
        if isinstance(exp, expr.Var):
            return 0
        if isinstance(
            exp,
            (
                expr.Add,
                expr.Sub,
                expr.Mul,
                expr.Div,
                expr.Mod,
                expr.FloorDiv,
                expr.FloorMod,
                expr.Max,
                expr.Min,
                expr.EQ,
                expr.NE,
                expr.LT,
                expr.LE,
                expr.GT,
                expr.GE,
                expr.And,
                expr.Or,
                expr.Not,
            ),
        ):
            base = 1

            if isinstance(exp, expr.Not):  # unary
                return base + _count_flop(exp.a)

            return base + _count_flop(exp.a) + _count_flop(exp.b)
        if isinstance(exp, expr.Select):
            # Only one branch executes per element, so count the costlier one.
            return _count_flop(exp.condition) + max(
                _count_flop(exp.true_value), _count_flop(exp.false_value)
            )
        if isinstance(exp, expr.ProducerLoad):
            # Ignore flops from indexing expressions.
            return 0

        if isinstance(exp, expr.Call):
            return sum(_count_flop(x) for x in exp.args)

        raise FlopCalculationError("Found unsupported operator in the compute expr")

    def traverse(ops):
        """accumulate flops"""
        # NOTE(review): there is no visited set, so a producer op reachable via
        # several consumers is counted once per path.
        ret = 0
        for op in ops:
            if isinstance(op, tensor.ComputeOp):
                num_element = _prod_length(op.axis)

                body = op.body
                if len(body) != 1:
                    raise FlopCalculationError("Found multiple output in the compute")
                exp = body[0]

                ret += num_element * _count_flop(exp)
                ret += traverse([t.op for t in op.input_tensors])

            elif isinstance(op, tensor.PlaceholderOp):
                pass
            else:
                raise FlopCalculationError(
                    "Only support te.compute currently. "
                    "Other ops like tvm.te.scan/te.extern is not supported"
                )
        return ret

    try:
        ret = traverse(sch.outputs)
    except FlopCalculationError as exc:
        raise RuntimeError(
            "FLOP estimator fails for this operator. Error msg: "
            + str(exc)
            + ". Please use `cfg.add_flop` to manually set "
            "FLOP for this operator"
        ) from exc

    if ret == 0:
        raise RuntimeError(
            "Cannot find float number operation in this operator. "
            "Please use `cfg.add_flop` to manually set "
            "FLOP for this operator"
        )
    return ret