1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17"""The build utils in python. 18 19This module provides the functions to transform schedule to 20LoweredFunc and compiled Module. 21""" 22from __future__ import absolute_import as _abs 23import warnings 24 25from ._ffi.function import Function 26from ._ffi.node import NodeBase, register_node 27from . import api 28from . import _api_internal 29from . import tensor 30from . import schedule 31from . import expr 32from . import ir_pass 33from . import stmt as _stmt 34from . import container 35from . import module 36from . import codegen 37from . import ndarray 38from . import target as _target 39from . import make 40 41class DumpIR(object): 42 """ 43 Dump IR for each pass. 44 With it, you can dump ir just like gcc/llvm. 45 46 How to use: 47 ----------- 48 .. 
code-block:: python 49 50 with tvm.build_config(dump_pass_ir=True) 51 run() 52 """ 53 scope_level = 0 54 def __init__(self): 55 self._pass_id = 0 56 self._recover_list = [] 57 58 def decorate(self, func): 59 """ decorate the pass function""" 60 def dump(*args, **kwargs): 61 """dump function""" 62 retv = func(*args, **kwargs) 63 if not isinstance(retv, (_stmt.Stmt, container.LoweredFunc, container.Array)): 64 return retv 65 fname = func.func_name if hasattr(func, 'func_name') else func.__name__ 66 pname = str(self._pass_id) + "_" + fname + "_ir.cc" 67 with open(pname, "a") as f: 68 out = retv.body if isinstance(retv, container.LoweredFunc) else retv 69 f.write(str(out)) 70 if isinstance(retv, container.Array): 71 for x in retv: 72 out = x.body if isinstance(x, container.LoweredFunc) else x 73 f.write("---------%s\n%s\n-----------\n"%(x.name, str(out))) 74 self._pass_id += 1 75 return retv 76 return dump 77 78 def decorate_irpass(self): 79 """decorate ir_pass and ScheduleOps""" 80 self._old_sgpass = schedule.ScheduleOps 81 schedule.ScheduleOps = self.decorate(schedule.ScheduleOps) 82 vset = vars(ir_pass) 83 k = v = 0 84 def recover(): 85 vset[k] = v 86 for k, v in vset.items(): 87 self._recover_list.append(recover) 88 vset[k] = self.decorate(v) if isinstance(v, Function) else v 89 90 def decorate_custompass(self, custom_pass): 91 """decorate given list of custom passes, and return decorated passes""" 92 custom_pass = custom_pass if custom_pass else [] 93 pass_list = [] 94 for idx, x in enumerate(custom_pass): 95 x[1].__name__ = "custom{}_phase{}".format(idx, x[0]) 96 pass_list += [(x[0], self.decorate(x[1]))] 97 return pass_list 98 99 def enter(self): 100 """only decorate outermost nest""" 101 if DumpIR.scope_level > 0: 102 return 103 self.decorate_irpass() 104 self._pass_id = 0 105 DumpIR.scope_level += 1 106 107 def exit(self): 108 """recover outermost nest""" 109 if DumpIR.scope_level > 1: 110 return 111 # recover decorated functions 112 for f in 
self._recover_list: 113 f() 114 schedule.ScheduleOps = self._old_sgpass 115 DumpIR.scope_level -= 1 116 117 118@register_node 119class BuildConfig(NodeBase): 120 """Configuration scope to set a build config option. 121 122 Note 123 ---- 124 This object is backed by node system in C++, with arguments that can be 125 exchanged between python and C++. 126 127 Do not construct directly, use build_config instead. 128 129 The fields that are backed by the C++ node are immutable once an instance 130 is constructed. See _node_defaults for the fields. 131 """ 132 133 _node_defaults = { 134 "auto_unroll_max_step": 0, 135 "auto_unroll_max_depth": 8, 136 "auto_unroll_max_extent": 0, 137 "unroll_explicit": True, 138 "detect_global_barrier": False, 139 "partition_const_loop": False, 140 "offset_factor": 0, 141 "data_alignment": -1, 142 "restricted_func": True, 143 "double_buffer_split_loop": 1, 144 "dump_pass_ir": False, 145 "instrument_bound_checkers": False, 146 "disable_select_rewriting": False, 147 "disable_vectorize": False, 148 "disable_assert": False 149 } 150 _dump_ir = DumpIR() 151 152 # pylint: disable=no-member 153 def __init__(self, handle): 154 """Initialize the function with handle 155 156 Parameters 157 ---------- 158 handle : SymbolHandle 159 the handle to the underlying C++ Symbol 160 """ 161 super(BuildConfig, self).__init__(handle) 162 self.handle = handle 163 164 @property 165 def add_lower_pass(self): 166 size = _api_internal._BuildConfigGetAddLowerPassInfo(self) 167 result = [] 168 for i in range(size): 169 phase = _api_internal._BuildConfigGetAddLowerPassInfo(self, i, True) 170 func = _api_internal._BuildConfigGetAddLowerPassInfo(self, i, False) 171 result += [(phase, func)] 172 return result 173 174 @add_lower_pass.setter 175 def add_lower_pass(self, value): 176 add_lower_pass_args = [] 177 for x in value: 178 add_lower_pass_args += [x[0], x[1]] 179 _api_internal._BuildConfigSetAddLowerPass(self, *add_lower_pass_args) 180 181 def __enter__(self): 182 # 
pylint: disable=protected-access 183 _api_internal._EnterBuildConfigScope(self) 184 if self.dump_pass_ir: 185 BuildConfig._dump_ir.enter() 186 return self 187 188 def __exit__(self, ptype, value, trace): 189 if self.dump_pass_ir: 190 BuildConfig._dump_ir.exit() 191 _api_internal._ExitBuildConfigScope(self) 192 193 def __setattr__(self, name, value): 194 if name in BuildConfig._node_defaults: 195 raise AttributeError( 196 "'%s' object cannot set attribute '%s'" % (str(type(self)), name)) 197 return super(BuildConfig, self).__setattr__(name, value) 198 199 200def current_build_config(): 201 """Get the current build configuration.""" 202 return _api_internal._GetCurrentBuildConfig() 203 204 205def build_config(**kwargs): 206 """Configure the build behavior by setting config variables. 207 208 Parameters 209 ---------- 210 auto_unroll_max_step: int, default=0 211 Threshold of number of steps in the loop to be automatically unrolled. 212 This takes inner loop count into consideration. 213 214 auto_unroll_max_depth: int, default=8 215 The maximum nested level of loops that can be automatically unrolled. 216 217 unroll_explicit: bool, default=True 218 Whether explicitly unroll the loop, if set false, the unroll hint will 219 be passed to the CodeGen phase, which may generate pragma unroll hint. 220 Set this to be true if CodeGen support unroll pragma and 221 when we want to be more readable. 222 223 detect_global_barrier: bool, default=True 224 Whether detect global barrier. 225 226 partition_const_loop: bool, default=False 227 Whether partition const loop 228 229 data_alignment: int, optional 230 The alignment of data pointer in bytes. 231 If -1 is passed, the alignment will be set to TVM's internal default. 232 233 offset_factor: int, default=0 234 The factor used in default buffer declaration. 235 If specified as 0, offset field is not used. 236 237 restricted_func: bool, default=True 238 Whether build restricted function. 
        That is each buffer argument to the function are guaranteed
        not to overlap. This enables more optimization.
        Corresponds to restricted keyword in C99

    double_buffer_split_loop: int, default=1
        Whether split the loop with factor. If it is zero, no splitting will happen.
        If it is bigger than one, the logic will do a split with factor equals the integer
        and unroll the inner loop. This allows the buffer fetching won't contain condition.

    add_lower_pass: list of tuple (phase, function(Stmt->Stmt)), default=None
        phase contains an integer on which optimization pass we apply the pass.
        Additional lowering passes to be applied before make_api.

    dump_pass_ir: dump ir of each pass into file idx_passname_ir.cc, default=False

    Returns
    -------
    config: BuildConfig
        The build configuration
    """
    # Start from the node defaults and overlay any user-supplied values.
    node_args = {k: v if k not in kwargs else kwargs[k]
                 for k, v in BuildConfig._node_defaults.items()}
    config = make.node("BuildConfig", **node_args)

    # add_lower_pass is not a plain node field; it goes through the
    # dedicated property setter on BuildConfig.
    if "add_lower_pass" in kwargs:
        config.add_lower_pass = kwargs["add_lower_pass"]

    return config

def get_binds(args, compact=False, binds=None):
    """Internal function to get binds and arg_list given arguments.

    Parameters
    ----------
    args : list of Buffer or Tensor or Var
        The argument lists to the function.

    compact : bool
        If the statement has already bound to a compact buffer.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps the Tensor to Buffer which specified the data layout
        requirement of the function. By default, a new compact buffer is created
        for each tensor in the argument.

    Returns
    -------
    binds: dict
        The bind specification

    arg_list: list
        The list of symbolic buffers of arguments.
    """
    # Copy so the caller's mapping is never mutated.
    binds = {} if binds is None else binds.copy()
    cfg = current_build_config()
    arg_list = []
    for x in args:
        if isinstance(x, tensor.Tensor):
            # A symbolic (Var) dimension without a compact statement gets the
            # "auto_broadcast" buffer type.
            any_dim = any(isinstance(i, expr.Var) for i in x.shape)
            buffer_type = "auto_broadcast" if any_dim and not compact else ""
            if x not in binds:
                # Declare a fresh buffer honoring the current build config's
                # alignment and offset settings.
                buf = api.decl_buffer(x.shape,
                                      dtype=x.dtype,
                                      name=x.name,
                                      data_alignment=cfg.data_alignment,
                                      offset_factor=cfg.offset_factor,
                                      buffer_type=buffer_type)
                binds[x] = buf
                arg_list.append(buf)
            else:
                arg_list.append(binds[x])
        elif isinstance(x, schedule.Buffer):
            arg_list.append(x)
        elif isinstance(x, expr.Var):
            arg_list.append(x)
        else:
            raise ValueError("args must be Tensor, Buffer or Var")
    return binds, arg_list


def form_body(sch):
    """According to the given schedule, form the raw body
    Parameters
    ----------
    sch : tvm.schedule.Schedule
        The given scheduler to form the raw body

    Returns
    -------
    The body formed according to the given schedule
    """
    # normalize schedule first
    sch = sch.normalize()
    bounds = schedule.InferBound(sch)
    stmt = schedule.ScheduleOps(sch, bounds)
    stmt = ir_pass.InjectPrefetch(stmt)
    return stmt


def lower(sch,
          args,
          name="default_function",
          binds=None,
          simple_mode=False):
    """Lowering step before build into target.

    Parameters
    ----------
    sch : tvm.schedule.Schedule
        The schedule to be built

    args : list of Buffer or Tensor or Var
        The argument lists to the function.

    name : str, optional
        The name of result function.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps the Tensor to Buffer which specified the data layout
        requirement of the function. By default, a new compact buffer is created
        for each tensor in the argument.
    simple_mode : bool, optional
        Whether only output simple and compact statement, this will skip
        LoopPartition, api wrapper generation and Unrolling.

    Returns
    -------
    f : LoweredFunc or Stmt
       The result function. If simple_mode is True, the Stmt before
       make api is returned instead of a LoweredFunc.
    """
    cfg = current_build_config()
    add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
    if cfg.dump_pass_ir:
        add_lower_pass = BuildConfig._dump_ir.decorate_custompass(add_lower_pass)
    # Custom passes are grouped by the phase (0, 1, 2, >2) at which they run.
    lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
    lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
    lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
    lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]

    # Phase 0: form the raw body from the schedule.
    # NOTE(review): if sch is not a Schedule, stmt is assumed to have been
    # produced elsewhere; a non-Schedule sch reaching this point unbound
    # would raise NameError below.
    if isinstance(sch, schedule.Schedule):
        stmt = form_body(sch)

    for f in lower_phase0:
        stmt = f(stmt)

    compact = ir_pass.VerifyCompactBuffer(stmt)
    binds, arg_list = get_binds(args, compact, binds)

    # Phase 1: flatten storage and simplify.
    stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
    stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
    stmt = ir_pass.CanonicalSimplify(stmt)
    for f in lower_phase1:
        stmt = f(stmt)

    # Phase 2: loop transforms (partition/vectorize/unroll); LoopPartition
    # is skipped in simple_mode.
    if not simple_mode:
        stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
    if cfg.disable_vectorize:
        stmt = ir_pass.SkipVectorize(stmt)
    else:
        stmt = ir_pass.VectorizeLoop(stmt)
    stmt = ir_pass.InjectVirtualThread(stmt)
    stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
    stmt = ir_pass.StorageRewrite(stmt)
    stmt = ir_pass.UnrollLoop(
        stmt,
        cfg.auto_unroll_max_step,
        cfg.auto_unroll_max_depth,
        cfg.auto_unroll_max_extent,
        cfg.unroll_explicit)
    for f in lower_phase2:
        stmt = f(stmt)

    # Phase 3: final cleanup.
    stmt = ir_pass.Simplify(stmt)
    stmt = ir_pass.RemoveNoOp(stmt)
    if not cfg.disable_select_rewriting:
        stmt = ir_pass.RewriteUnsafeSelect(stmt)
    for f in lower_phase3:
        stmt = f(stmt)
    # Instrument BoundCheckers
    if cfg.instrument_bound_checkers:
        stmt = ir_pass.InstrumentBoundCheckers(stmt)
    if simple_mode:
        return stmt

    return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)


def _build_for_device(flist, target, target_host):
    """Build the lowered functions for a device with the given compilation
    target.

    Parameters
    ----------
    flist : list of LoweredFunc
        The schedule to be built.

    target : str or :any:`tvm.target.Target`
        The target and option of the compilation.

    target_host : str or :any:`tvm.target.Target`
        The host compilation target.

    Returns
    -------
    fhost : list of LoweredFunc
        A list of lowered functions for the host.

    mdev : tvm.module
        A module that contains device code.
    """
    target = _target.create(target)
    device_type = ndarray.context(target.target_name, 0).device_type
    fhost = []
    fdevice = []
    # Split each function into host-side and device-side parts by func_type.
    for func in flist:
        if not ir_pass.VerifyMemory(func, device_type):
            raise ValueError(
                "Direct host side access to device memory is detected in %s. "
                "Did you forget to bind?" % func.name)
        if func.func_type == container.LoweredFunc.MixedFunc:
            if current_build_config().detect_global_barrier:
                func = ir_pass.ThreadSync(func, "global")
            func = ir_pass.ThreadSync(func, "shared")
            func = ir_pass.ThreadSync(func, "warp")
            func = ir_pass.InferFragment(func)
            warp_size = target.thread_warp_size
            func = ir_pass.LowerThreadAllreduce(func, warp_size)
            # First split is the host function; the rest are device kernels.
            fsplits = [s for s in ir_pass.SplitHostDevice(func)]
            fhost.append(fsplits[0])
            for x in fsplits[1:]:
                fdevice.append(x)
        elif func.func_type == container.LoweredFunc.HostFunc:
            fhost.append(func)
        elif func.func_type == container.LoweredFunc.DeviceFunc:
            fdevice.append(func)
        else:
            raise ValueError("unknown function type %d" % func.func_type)

    for i, func in enumerate(fdevice):
        warp_size = target.thread_warp_size
        fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size)

    # A GPU target with no device kernels usually means missing thread binding.
    if "gpu" in target.keys and not fdevice:
        warnings.warn(
            "Specified target %s, but cannot find device code, did you do "
            "bind?" % target)

    fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
    fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost]

    if device_type == ndarray.cpu(0).device_type and target_host == target:
        assert not fdevice

    target_host = _target.create(target_host)
    fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice]
    fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost]
    # Intrinsics are lowered against the respective codegen target.
    fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice]
    fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
    fhost = [ir_pass.CombineContextCall(x) for x in fhost]
    mdev = codegen.build_module(fdevice, str(target)) if fdevice else None

    return fhost, mdev


def build(inputs,
          args=None,
          target=None,
          target_host=None,
          name="default_function",
          binds=None):
    """Build a function with arguments as signature. Code will be generated
    for devices coupled with target information.

    Parameters
    ----------
    inputs : tvm.Schedule, LoweredFunc, or dict of target to LoweredFunc list
        The schedule to be built

    args : list of Buffer or Tensor or Var, optional
        The argument lists to the function.

    target : str or :any:`tvm.target.Target`, optional
        The target and option of the compilation.

    target_host : str or :any:`tvm.target.Target` optional
        Host compilation target, if target is device.
        When TVM compiles device specific program such as CUDA,
        we also need host(CPU) side code to interact with the driver
        setup the dimensions and parameters correctly.
        target_host is used to specify the host side codegen target.
        By default, llvm is used if it is enabled,
        otherwise a stackvm interpreter is used.

    name : str, optional
        The name of result function.

    binds : dict, optional
        Dictionary that maps the binding of symbolic buffer to Tensor.
        By default, a new buffer is created for each tensor in the argument.

    Returns
    -------
    ret : tvm.module
        A module that combines both host and device code.

    Examples
    --------
    There are two typical example uses of this function depending on the type
    of the argument `inputs`:
    1. it is a list of lowered functions:

    .. code-block:: python

        n = 2
        A = tvm.placeholder((n,), name='A')
        B = tvm.placeholder((n,), name='B')
        C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
        s = tvm.create_schedule(C.op)
        f = tvm.lower(s, [A, B, C], name="test_add")
        m = tvm.build(f, target="llvm")

    2. it is a dict of compilation target to list of lowered functions:

    .. code-block:: python

        n = 2
        A = tvm.placeholder((n,), name='A')
        B = tvm.placeholder((n,), name='B')
        C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
        s1 = tvm.create_schedule(C.op)
        with tvm.target.cuda() as cuda_tgt:
            s2 = topi.cuda.schedule_injective(cuda_tgt, [C])
            f1 = tvm.lower(s1, [A, B, C], name="test_add1")
            f2 = tvm.lower(s2, [A, B, C], name="test_add2")
            m = tvm.build({"llvm": [f1], "cuda": [f2]}, target_host="llvm")

    Note
    ----
    See the note on :any:`tvm.target` on target string format.
    """
    # Normalize `inputs` into a flat list of LoweredFunc (flist), lowering a
    # Schedule first if necessary.
    if isinstance(inputs, schedule.Schedule):
        if args is None:
            raise ValueError("args must be given for build from schedule")
        flist = lower(inputs, args,
                      name=name,
                      binds=binds)
        if isinstance(flist, container.LoweredFunc):
            flist = [flist]
    elif isinstance(inputs, container.LoweredFunc):
        if args:
            raise ValueError("args must be done when build from LoweredFunc.")
        flist = [inputs]
    elif isinstance(inputs, (list, tuple, container.Array)):
        flist = inputs
    elif not isinstance(inputs, (dict, container.Map)):
        raise ValueError("inputs must be Schedule, LoweredFunc, list of "
                         "LoweredFunc, or dict of target to list of "
                         "LoweredFunc.")

    # Map each compilation target to its list of functions. A non-dict input
    # is compiled for a single target (explicit, current, or "llvm").
    if not isinstance(inputs, (dict, container.Map)):
        target = _target.current_target() if target is None else target
        target = target if target else "llvm"
        target_flist = {target: flist}
    else:
        target_flist = inputs

    # Validate the target->functions mapping and reject duplicate names
    # within one target.
    for tar, flist in target_flist.items():
        if not isinstance(tar, (str, _target.Target)):
            raise ValueError("The key of inputs must be str or "
                             "_target.Target when inputs is dict.")
        fname_set = set()
        for x in flist:
            if not isinstance(x, container.LoweredFunc):
                raise ValueError("inputs must be Schedule, LoweredFunc, list "
                                 "of LoweredFunc, or dict of str to list of "
                                 "LoweredFunc.")
            if x.name in fname_set:
                raise ValueError("Duplicate function name %s" % x.name)
            fname_set.add(x.name)

    # Infer target_host: prefer the first CPU target among the inputs, else
    # fall back to llvm (or the stackvm interpreter if llvm is disabled).
    if not target_host:
        for tar, _ in target_flist.items():
            tar = _target.create(tar)
            device_type = ndarray.context(tar.target_name, 0).device_type
            if device_type == ndarray.cpu(0).device_type:
                target_host = tar
                break
    if not target_host:
        target_host = "llvm" if module.enabled("llvm") else "stackvm"

    fhost_all = []
    device_modules = []
    for tar, flist in target_flist.items():
        fhost, mdev = _build_for_device(flist, tar, target_host)
        # Save the current lowered functions of the host and the device module.
        fhost_all += fhost
        device_modules.append(mdev)

    # Generate a unified host module.
    mhost = codegen.build_module(fhost_all, str(target_host))

    # Import all modules.
    for mdev in device_modules:
        if mdev:
            mhost.import_module(mdev)
    return mhost