# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import
"""The computation schedule api of TVM."""
import tvm._ffi
from tvm._ffi.base import string_types

from tvm.runtime import Object, convert
from tvm.ir import container as _container
from tvm.tir import IterVar, Buffer

from . import tensor as _tensor
from . import _ffi_api


@tvm._ffi.register_object
class Split(Object):
    """Split operation on axis."""


@tvm._ffi.register_object
class Fuse(Object):
    """Fuse operation on axis."""


@tvm._ffi.register_object
class Singleton(Object):
    """Singleton axis."""


def create_schedule(ops):
    """Create a schedule for list of ops

    Parameters
    ----------
    ops : list of Operations
        The source expression.

    Returns
    -------
    sch : schedule.Schedule
        The created schedule.
    """
    # Accept a single op as a convenience; normalize to a list for the FFI call.
    if not isinstance(ops, (list, _container.Array)):
        ops = [ops]
    return _ffi_api.CreateSchedule(ops)


@tvm._ffi.register_object
class Schedule(Object):
    """Schedule for all the stages."""

    def __getitem__(self, k):
        # Allow indexing by Tensor (resolved to its producing op) or Operation.
        if isinstance(k, _tensor.Tensor):
            k = k.op
        if not isinstance(k, _tensor.Operation):
            raise ValueError("Expect schedule key to be Tensor or Operation")
        if k not in self.stage_map:
            raise ValueError("Cannot find the operation %s in schedule" % (str(k)))
        return self.stage_map[k]

    def normalize(self):
        """Build a normalized schedule from the current schedule.

        Insert necessary rebase to make certain iter var to start from 0.
        This is needed before bound inference and followup step.

        Returns
        -------
        sch : Schedule
            The normalized schedule.
        """
        return _ffi_api.ScheduleNormalize(self)

    def create_group(self, outputs, inputs, include_inputs=False):
        """Create stage group by giving output and input boundary.

        The operators between outputs and inputs are placed as member of group.
        outputs are included in the group, while inputs are not included.

        Parameters
        ----------
        outputs : list of Tensors
            The outputs of the group.

        inputs : list of Tensors
            The inputs of the group.

        include_inputs : boolean, optional
            Whether to include input operations in the group if they are used by outputs.

        Returns
        -------
        group : Stage
            A virtual stage represents the group, user can use compute_at to move
            the attachment point of the group.
        """
        if isinstance(outputs, _tensor.Tensor):
            outputs = [outputs]
        if isinstance(inputs, _tensor.Tensor):
            inputs = [inputs]
        return _ffi_api.ScheduleCreateGroup(self, outputs, inputs, include_inputs)

    def cache_read(self, tensor, scope, readers):
        """Create a cache read of original tensor for readers.

        This will mutate the body of the readers.
        A new cache stage will be created for the tensor.
        Call this before doing any split/fuse schedule.

        Parameters
        ----------
        tensor : Tensor
            The tensor to be cached.
        scope : str
            The scope of cached
        readers : list of Tensor or Operation
            The readers to read the cache.

        Returns
        -------
        cache : Tensor
            The created cache tensor.
        """
        # Normalize a single reader to a list, then resolve Tensors to their ops.
        if isinstance(readers, (_tensor.Tensor, _tensor.Operation)):
            readers = [readers]
        readers = [t.op if isinstance(t, _tensor.Tensor) else t for t in readers]
        return _ffi_api.ScheduleCacheRead(self, tensor, scope, readers)

    def cache_write(self, tensor, scope):
        """Create a cache write of original tensor, before storing into tensor.

        This will mutate the body of the tensor.
        A new cache stage will be created before feeding into the tensor.

        This function can be used to support data layout transformation.
        If there is a split/fuse/reorder on the data parallel axis of tensor
        before cache_write is called. The intermediate cache stores
        the data in the layout as the iteration order of leaf axis.
        The data will be transformed back to the original layout in the original tensor.
        User can further call compute_inline to inline the original layout and keep
        the data stored in the transformed layout.

        Parameters
        ----------
        tensor : Tensor, list or tuple
            The tensors to be fed to. All the tensors must be produced by one computeOp
        scope : str
            The scope of cached

        Returns
        -------
        cache : Tensor
            The created cache tensor.
        """
        return _ffi_api.ScheduleCacheWrite(self, tensor, scope)

    def rfactor(self, tensor, axis, factor_axis=0):
        """Factor a reduction axis in tensor's schedule to be an explicit axis.

        This will create a new stage that generates the new tensor with axis
        as the first dimension. The tensor's body will be rewritten as a reduction
        over the factored tensor.

        Parameters
        ----------
        tensor : Tensor
            The tensor to be factored.
        axis : IterVar
            The reduction axis in the schedule to be factored.
        factor_axis : int
            The position where the new axis is placed.

        Returns
        -------
        tfactor : Tensor or Array of Tensor
            The created factored tensor.
        """
        factored = _ffi_api.ScheduleRFactor(self, tensor, axis, factor_axis)
        # Unwrap a single-element result for convenience; multi-output ops
        # return the full array.
        return factored[0] if len(factored) == 1 else factored


@tvm._ffi.register_object
class Stage(Object):
    """A Stage represents schedule for one operation."""

    def split(self, parent, factor=None, nparts=None):
        """Split the stage either by factor providing outer scope, or both

        Parameters
        ----------
        parent : IterVar
            The parent iter var.

        factor : Expr, optional
            The splitting factor

        nparts : Expr, optional
            The number of outer parts.

        Returns
        -------
        outer : IterVar
            The outer variable of iteration.

        inner : IterVar
            The inner variable of iteration.
        """
        # Exactly one of factor/nparts must be supplied; they describe the
        # same split from opposite ends.
        if nparts is not None:
            if factor is not None:
                raise ValueError("Do not need to provide both outer and nparts")
            outer, inner = _ffi_api.StageSplitByNParts(self, parent, nparts)
        else:
            if factor is None:
                raise ValueError("Either nparts or factor need to be provided")
            outer, inner = _ffi_api.StageSplitByFactor(self, parent, factor)
        return outer, inner

    def fuse(self, *args):
        """Fuse multiple consecutive iteration variables into a single iteration variable.

        fused = fuse(...fuse(fuse(args[0], args[1]), args[2]),..., args[-1])
        The order is from outer to inner.

        Parameters
        ----------
        args : list of IterVars
            IterVars that are consecutive with each other

        Returns
        -------
        fused : IterVar
            The fused variable of iteration.
        """
        fused = _ffi_api.StageFuse(self, args)
        return fused

    def set_scope(self, scope):
        """Set the thread scope of this stage

        Parameters
        ----------
        scope : str
            The thread scope of this stage
        """
        return _ffi_api.StageSetScope(self, scope)

    def bind(self, ivar, thread_ivar):
        """Bind ivar to thread index thread_ivar

        Parameters
        ----------
        ivar : IterVar
            The iteration to be bound to thread.

        thread_ivar : IterVar
            The thread to be bound.
        """
        _ffi_api.StageBind(self, ivar, thread_ivar)

    def env_threads(self, threads):
        """Mark threads to be launched at the outer scope of composed op.

        Parameters
        ----------
        threads : list of threads
            The threads to be launched.
        """
        if isinstance(threads, IterVar):
            threads = [threads]
        _ffi_api.StageEnvThreads(self, threads)

    def set_store_predicate(self, predicate):
        """Set predicate under which store to the array can be performed.

        Use this when there are duplicated threads doing the same store and we only
        need one of them to do the store.

        Parameters
        ----------
        predicate : Expr
            The guard condition for store.
        """
        _ffi_api.StageSetStorePredicate(self, predicate)

    def compute_at(self, parent, scope):
        """Attach the stage at parent's scope

        Parameters
        ----------
        parent : Stage
            The parent stage

        scope : IterVar
            The loop scope to be attached to.
        """
        _ffi_api.StageComputeAt(self, parent, scope)

    def compute_inline(self):
        """Mark stage as inline

        Parameters
        ----------
        parent : Stage
            The parent stage
        """
        _ffi_api.StageComputeInline(self)

    def compute_root(self):
        """Attach the stage at parent, and mark it as root

        Parameters
        ----------
        parent : Stage
            The parent stage
        """
        _ffi_api.StageComputeRoot(self)

    def reorder(self, *args):
        """reorder the arguments in the specified order.

        Parameters
        ----------
        args : list of IterVar
            The order to be ordered
        """
        _ffi_api.StageReorder(self, args)

    def tile(self, x_parent, y_parent, x_factor, y_factor):
        """Perform tiling on two dimensions

        The final loop order from outmost to inner most are
        [x_outer, y_outer, x_inner, y_inner]

        Parameters
        ----------
        x_parent : IterVar
            The original x dimension
        y_parent : IterVar
            The original y dimension
        x_factor : Expr
            The stride factor on x axis
        y_factor : Expr
            The stride factor on y axis

        Returns
        -------
        x_outer : IterVar
            Outer axis of x dimension
        y_outer : IterVar
            Outer axis of y dimension
        x_inner : IterVar
            Inner axis of x dimension
        y_inner : IterVar
            Inner axis of y dimension
        """
        x_outer, y_outer, x_inner, y_inner = _ffi_api.StageTile(
            self, x_parent, y_parent, x_factor, y_factor
        )
        return x_outer, y_outer, x_inner, y_inner

    def vectorize(self, var):
        """Vectorize the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be vectorized
        """
        _ffi_api.StageVectorize(self, var)

    def tensorize(self, var, tensor_intrin):
        """Tensorize the computation enclosed by var with tensor_intrin

        Parameters
        ----------
        var : IterVar
            The iteration boundary of tensorization.

        tensor_intrin : TensorIntrin
            The tensor intrinsic used for computation.
        """
        _ffi_api.StageTensorize(self, var, tensor_intrin)

    def unroll(self, var):
        """Unroll the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be unrolled.
        """
        _ffi_api.StageUnroll(self, var)

    def parallel(self, var):
        """Parallelize the iteration.

        Parameters
        ----------
        var : IterVar
            The iteration to be parallelized.
        """
        _ffi_api.StageParallel(self, var)

    def pragma(self, var, pragma_type, pragma_value=None):
        """Annotate the iteration with pragma

        This will translate to a pragma_scope surrounding
        the corresponding loop generated.
        Useful to support experimental features and extensions.

        Parameters
        ----------
        var : IterVar
            The iteration to be annotated

        pragma_type : str
             The pragma string to be annotated

        pragma_value : Expr, optional
             The pragma value to pass along the pragma

        Note
        ----
        Most pragmas are advanced/experimental features
        and may be subject to change. List of supported pragmas:

        - **debug_skip_region**

          Force skip the region marked by the axis and turn it into no-op.
          This is useful for debug purposes.

        - **parallel_launch_point**

          Specify to launch parallel threads outside the
          specified iteration loop. By default the threads
          launch at the point of parallel construct.
          This pragma moves the launching point to even outer scope.
          The threads are launched once and reused across multiple
          parallel constructs as BSP style program.

        - **parallel_barrier_when_finish**

          Insert a synchronization barrier between working threads
          after the specified loop iteration finishes.

        - **parallel_stride_pattern**

          Hint parallel loop to execute in strided pattern.
          :code:`for (int i = task_id; i < end; i += num_task)`

        """
        # Plain Python strings must be converted to TVM string objects
        # before crossing the FFI boundary.
        if isinstance(pragma_value, string_types):
            pragma_value = convert(pragma_value)
        _ffi_api.StagePragma(self, var, pragma_type, pragma_value)

    def prefetch(self, tensor, var, offset):
        """Prefetch the specified variable

        Parameters
        ----------
        tensor : Tensor
            The tensor to be prefetched
        var : IterVar
            The loop point at which the prefetching is applied
        offset : Expr
            The number of iterations to be prefetched before actual execution
        """
        _ffi_api.StagePrefetch(self, tensor, var, offset)

    def storage_align(self, axis, factor, offset):
        """Set alignment requirement for specific axis

        This ensures that stride[axis] == k * factor + offset for some k.
        This is useful to set memory layout for a more friendly memory
        access pattern. For example, we can set alignment to be
        factor=2, offset=1 to avoid bank conflict for thread access on
        higher dimension in GPU shared memory.

        Parameters
        ----------
        axis : IterVar
            The axis dimension to be aligned.
        factor : int
            The factor in alignment specification.
        offset : int
            The offset in the alignment specification.
        """
        _ffi_api.StageStorageAlign(self, axis, factor, offset)

    def double_buffer(self):
        """Compute the current stage via double buffering.

        This can only be applied to intermediate stage.
        This will double the storage cost of the current stage.
        Can be useful to hide load latency.
        """
        _ffi_api.StageDoubleBuffer(self)


@tvm._ffi.register_object
class SpecializedCondition(Object):
    """Specialized condition to enable op specialization."""

    def __init__(self, conditions):
        """Create a specialized condition.

        .. note::
            Conditions are represented in conjunctive normal form (CNF).
            Each condition should be a simple expression, e.g., n > 16,
            m % 8 == 0, etc., where n, m are tvm.Var that represents a
            dimension in the tensor shape.

        Parameters
        ----------
        conditions : List of tvm.Expr
            List of conditions in conjunctive normal form (CNF).
        """
        if not isinstance(conditions, (list, _container.Array)):
            conditions = [conditions]
        self.__init_handle_by_constructor__(_ffi_api.CreateSpecializedCondition, conditions)

    @staticmethod
    def current():
        """Returns the current specialized condition"""
        return _ffi_api.GetCurrentSpecialization()

    def __enter__(self):
        _ffi_api.EnterSpecializationScope(self)
        return self

    def __exit__(self, ptype, value, trace):
        _ffi_api.ExitSpecializationScope(self)


tvm._ffi._init_api("schedule", __name__)