1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17# pylint: disable=too-few-public-methods,invalid-name,unused-argument,arguments-differ 18# pylint: disable=consider-using-enumerate,too-many-lines 19""" 20Template configuration space. 21 22Each template function can be parametrized by a ConfigSpace. 23The space is declared when we invoke the template function with ConfigSpace. 24During evaluation, we pass in a ConfigEntity, which contains a specific 25entity in the space. This entity contains deterministic parameters. 26""" 27from __future__ import absolute_import as _abs 28 29import itertools 30import functools 31import math 32from collections import namedtuple, OrderedDict 33import numpy as np 34 35from tvm import schedule, thread_axis 36from tvm.autotvm.util import get_const_int 37 38Axis = namedtuple('Axis', ['space', 'index']) 39 40try: 41 _long = long 42except NameError: 43 _long = int 44 45 46class InstantiationError(ValueError): 47 """Actively detected error in instantiating a template with a config, 48 raised by cfg.raise_error 49 e.g. too many unrolling, too many threads in a block 50 """ 51 52 53class TransformSpace(object): 54 """Base class for transform space 55 TransformSpace is the node in the computation graph of axes 56 57 Note 58 ---- 59 We can regard our schedule code as a transformation graph of axes. 60 Starting from raw axes in the definition of tvm.compute, we can transform these axes 61 by some operators. The operator includes 'split', 'reorder' and 'annotate'. 62 Each operator has some tunable parameters (e.g. the split factor). 63 Then the tuning process is just to find good parameters of these op. 64 65 So the all the combinations of the parameters of these op forms our search space. 66 67 Naming convention: 68 We call the set of all possible values as XXXSpace. (XXX can be Split, Reorder, Config ...) 69 We call a specific entity in a space as XXXEntity. 70 """ 71 def __init__(self): 72 self.ins = [] 73 self.num_output = 0 74 self.entities = [] 75 76 def __len__(self): 77 return len(self.entities) 78 79 def __getitem__(self, index): 80 """Get an entity of the space by index 81 82 Parameters 83 ---------- 84 index: int 85 86 Returns 87 ------- 88 transform entity 89 """ 90 return self.entities[index] 91 92 @staticmethod 93 def get_num_output(): 94 """get number of output axes after this transform 95 96 Returns 97 ------- 98 n: int 99 number of output axes 100 """ 101 return 0 102 103 104class VirtualAxis(TransformSpace): 105 """Axis placeholder in template 106 107 Parameters 108 ---------- 109 var: int or tvm.schedule.IterVar 110 If is int, return a virtual axis whose length is the provided argument. 111 If is IterVar, return a virtual axis whose length is extracted from 112 the IterVar's extent domain. 113 name: str 114 """ 115 name_ct = 0 116 117 def __init__(self, var, name=None): 118 super(VirtualAxis, self).__init__() 119 self.num_output = 1 120 121 if name is None: 122 name = 'axis_%d' % VirtualAxis.name_ct 123 VirtualAxis.name_ct += 1 124 125 self.name = name 126 if isinstance(var, (int, _long)): 127 self.length = var 128 elif isinstance(var, schedule.IterVar): 129 self.name = var.var.name 130 if var.dom is None: 131 self.length = -1 132 else: 133 self.length = get_const_int(var.dom.extent) 134 elif isinstance(var, VirtualAxis): 135 self.length = var.length 136 else: 137 raise RuntimeError("Invalid type of axis: " + str(type(var))) 138 139 @staticmethod 140 def get_num_output(var, name=None): 141 return 1 142 143 def __repr__(self): 144 return "vaxis(%s)" % self.name 145 146 147def get_factors(n): 148 """return all factors of an integer 149 150 Parameters 151 ---------- 152 n: int 153 integer to factorize 154 155 Returns 156 ------- 157 factors: list 158 List of all factors 159 """ 160 step = 2 if n % 2 else 1 161 ret = list(set( 162 functools.reduce( 163 list.__add__, ([i, n//i] for i in range(1, int(math.sqrt(n)) + 1, step) 164 if n % i == 0)))) 165 ret.sort() 166 return ret 167 168def get_pow2s(n): 169 """return all power-of-two numbers that are less or equal than the integer 170 171 Parameters 172 ---------- 173 n: int 174 integer for reference 175 176 Returns 177 ------- 178 factors: list 179 List of all power-of-two numbers 180 """ 181 return [2**x for x in range(math.floor(math.log2(n)) + 1)] 182 183class SplitSpace(TransformSpace): 184 """Split an axis for several times""" 185 def __init__(self, axes, policy, **kwargs): 186 super(SplitSpace, self).__init__() 187 axis = axes[0] 188 189 self.policy = policy 190 self.entities = [] 191 192 max_factor = kwargs.get("max_factor", 1 << 31) 193 fil = kwargs.get("filter", lambda x: True) 194 self.product = axis.length 195 self.num_output = kwargs.get("num_outputs", 0) 196 assert self.num_output > 0 197 198 if policy == 'candidate': 199 for size in kwargs["candidate"]: 200 assert len(size) == self.num_output 201 self.entities.append(SplitEntity(size)) 202 else: 203 if policy == 'verbose': 204 # Include factors and power-of-twos. May generate tails. 205 divisibles = get_factors(self.product) 206 pow2s = get_pow2s(self.product) 207 factors = [x for x in list(set(divisibles) | set(pow2s)) if x <= max_factor] 208 elif policy == 'factors': 209 # Include divisible factors. Guarantee no tails. 210 factors = [x for x in get_factors(self.product) if x <= max_factor] 211 elif policy == 'power2': 212 # Include less, equal, and round-up power-of-two numbers. May generate tails. 213 factors = [x for x in get_pow2s(self.product) if x <= max_factor] 214 else: 215 raise RuntimeError("Invalid policy: %s" % policy) 216 217 # Enforce the product of all split factors equals to the axis length 218 no_tail = kwargs.get("no_tail", policy == 'factors') 219 220 # Generate split entity by enumerating candidate factors. 221 self.factors = factors 222 self._generate_space(0, [None] * (self.num_output - 1), enforce_no_tail=no_tail) 223 224 self.entities = list(filter(fil, self.entities)) 225 226 def _generate_space(self, now, tmp_stack, enforce_no_tail=False): 227 """Generate space by DFS""" 228 if now == self.num_output - 1: 229 prod = functools.reduce(lambda x, y: x * y, tmp_stack) 230 if prod > self.product: 231 return 232 if self.product % prod == 0 or (not enforce_no_tail and prod < self.product): 233 self.entities.append(SplitEntity([-1] + tmp_stack[::-1])) 234 else: 235 for factor in self.factors: 236 tmp_stack[now] = factor 237 self._generate_space(now + 1, tmp_stack, enforce_no_tail) 238 239 @staticmethod 240 def get_num_output(axes, policy, **kwargs): 241 return kwargs["num_outputs"] 242 243 def __repr__(self): 244 return ("Split(policy=%s, product=%d, num_outputs=%d) len=%d" % 245 (self.policy, self.product, self.num_output, len(self))) 246 247 248class SplitEntity(object): 249 """ 250 A split operation with detailed parameters 251 that can apply to an axis 252 253 Parameters 254 ---------- 255 size: Array of int 256 the size of every axis after split 257 e.g. an axis of extent 128, we split it into 3 axes, a possible 258 size is [4, 4, 8] (4x4x8 = 128) 259 """ 260 def __init__(self, size): 261 self.size = size 262 263 def apply(self, sch, op, axis): 264 """Apply split to an axis 265 266 Parameters 267 ---------- 268 sch: tvm.schedule.Schedule 269 The tvm schedule 270 op: tvm.tensor.Operation 271 The stage to be applied 272 axis: tvm.schedule.IterVar 273 axis to split 274 275 Returns 276 ------- 277 axes : list of Axis 278 The transformed axes. 279 """ 280 ret = [] 281 for i in range(1, len(self.size)): 282 ax0, ax1 = sch[op].split(axis, int(np.prod(self.size[i:]))) 283 ret.append(ax0) 284 axis = ax1 285 return ret + [axis] 286 287 def __repr__(self): 288 return str(self.size) 289 290 291class ReorderSpace(TransformSpace): 292 """The parameter space for ordering an array of axes""" 293 def __init__(self, axes, policy, **kwargs): 294 super(ReorderSpace, self).__init__() 295 self.ins = axes 296 self.policy = policy 297 self.num_output = len(axes) 298 299 if policy == 'identity': 300 self.entities = [ReorderEntity(range(len(axes)))] 301 elif policy == 'all': 302 self.entities = [ 303 ReorderEntity(x) for x in itertools.permutations(range(len(axes)))] 304 elif policy == 'interval_all': 305 begin, end = kwargs['interval'] 306 sub_space = list(itertools.permutations(range(begin, end))) 307 prefix, suffix = tuple(range(begin)), tuple(range(end, len(axes))) 308 self.entities = [ReorderEntity(prefix + x + suffix) for x in sub_space] 309 elif policy == 'candidate': 310 candidate = kwargs["candidate"] 311 for can in candidate: 312 perm = [axes.index(x) for x in can] 313 self.entities.append(ReorderEntity(perm)) 314 elif policy == 'interleave': 315 spatial, reduce = kwargs['spatial'], kwargs['reduce'] 316 317 spatial = [[axes.index(x) for x in ch] for ch in spatial] 318 reduce = [[axes.index(x) for x in ch] for ch in reduce] 319 320 outer_merged = self._merge_chain([x[:-1] for x in spatial]) 321 inner_merged = self._merge_chain([x[-1:] for x in spatial] + reduce) 322 323 for o in outer_merged: 324 for i in inner_merged: 325 self.entities.append(ReorderEntity(o + i)) 326 elif policy == 'interleave_cuda': 327 spatial, reduce = kwargs['spatial'], kwargs['reduce'] 328 329 spatial = [[axes.index(x) for x in ch] for ch in spatial] 330 reduce = [[axes.index(x) for x in ch] for ch in reduce] 331 332 outer_merged = self._merge_chain([x[:-1] for x in spatial]) 333 reduce_merged = self._merge_chain(reduce) 334 inner_merged = [x[-1] for x in spatial] 335 336 for o in outer_merged: 337 for r in reduce_merged: 338 self.entities.append(ReorderEntity(o + r + inner_merged)) 339 else: 340 raise RuntimeError("Invalid policy: " + policy) 341 342 @staticmethod 343 def get_num_output(axes, policy, **kwargs): 344 return len(axes) 345 346 def __repr__(self): 347 return "Reorder(policy=%s) len=%d" % (self.policy, len(self)) 348 349 def _merge_chain(self, chains): 350 """generate all combinations of merge some chains""" 351 merged = [] 352 tmp_pt = [0] * len(chains) 353 tmp_stack = [] 354 355 size = np.sum([len(x) for x in chains]) 356 self._merge_dfs(chains, size, tmp_pt, tmp_stack, merged) 357 return merged 358 359 def _merge_dfs(self, chains, size, tmp_pt, tmp_stack, merged): 360 if np.sum(tmp_pt) == size: 361 merged.append(list(tmp_stack)) 362 return 363 364 for i in range(len(chains)): 365 # use i == np.argmax(....) here to take spatial order into consideration 366 # if we don't want to consider spatial order, we can use tmp_pt[i] == np.max(....) 367 if (tmp_pt[i] < len(chains[i]) and 368 (i == np.argmax([len(chains[x]) - tmp_pt[x] for x in range(len(chains))]))): 369 tmp_stack.append(chains[i][tmp_pt[i]]) 370 tmp_pt[i] += 1 371 self._merge_dfs(chains, size, tmp_pt, tmp_stack, merged) 372 tmp_pt[i] -= 1 373 tmp_stack.pop() 374 375 376class ReorderEntity(object): 377 """A reorder operation with detailed parameters that can apply to axes 378 379 Parameters 380 ---------- 381 perm: Array of int 382 define the permutation 383 """ 384 def __init__(self, perm): 385 self.perm = perm 386 387 def apply(self, sch, op, axes): 388 """Apply reorder to an array of axes 389 390 Parameters 391 ---------- 392 sch: tvm.schedule.Schedule 393 The tvm schedule 394 op: tvm.tensor.Operation 395 The stage to be applied 396 axis: tvm.schedule.IterVar 397 axis to split 398 399 Returns 400 ------- 401 axes : list of Axis 402 The transformed axes. 403 """ 404 if len(axes) == len(self.perm): 405 new_order = [axes[i] for i in self.perm] 406 else: 407 new_order = [axes[i] for i in self.perm if i < len(axes)] 408 sch[op].reorder(*new_order) 409 return new_order 410 411 def __repr__(self): 412 return str(self.perm) 413 414 415class AnnotateSpace(TransformSpace): 416 """The parameter space for annotating an array of axes""" 417 def __init__(self, axes, policy, **kwargs): 418 super(AnnotateSpace, self).__init__() 419 420 self.ins = axes 421 self.policy = policy 422 self.num_output = len(axes) 423 424 if policy == 'bind_gpu': 425 self.num_axis = len(axes) 426 if self.num_axis >= 6: 427 self.entities.append(AnnotateEntity( 428 ['fuse'] * (self.num_axis - 6) + 429 ['blockIdx.z', 'blockIdx.y', 'blockIdx.x', 430 'threadIdx.z', 'threadIdx.y', 'threadIdx.x'])) 431 elif self.num_axis >= 4: 432 self.entities.append(AnnotateEntity( 433 ['fuse'] * (self.num_axis - 4) + 434 ['blockIdx.y', 'blockIdx.x', 435 'threadIdx.y', 'threadIdx.x'])) 436 elif self.num_axis >= 2: 437 self.entities.append(AnnotateEntity( 438 ['fuse'] * (self.num_axis - 2) + 439 ['blockIdx.x', 'threadIdx.x'])) 440 else: 441 raise RuntimeError("Unhandled case in bind_gpu") 442 elif policy == 'bind_gpu_virtual': 443 self.num_axis = len(axes) 444 if self.num_axis >= 9: 445 self.entities.append(AnnotateEntity( 446 ['fuse'] * (self.num_axis - 9) + 447 ['blockIdx.z', 'blockIdx.y', 'blockIdx.x', 448 'vthread', 'vthread', 'vthread', 449 'threadIdx.z', 'threadIdx.y', 'threadIdx.x'])) 450 elif self.num_axis >= 6: 451 self.entities.append(AnnotateEntity( 452 ['fuse'] * (self.num_axis - 6) + 453 ['blockIdx.y', 'blockIdx.x', 454 'vthread', 'vthread', 455 'threadIdx.y', 'threadIdx.x'])) 456 elif self.num_axis >= 3: 457 self.entities.append(AnnotateEntity( 458 ['fuse'] * (self.num_axis - 3) + 459 ['blockIdx.x', 'vthread', 'threadIdx.x'])) 460 else: 461 raise RuntimeError("Unhandled case in bind_gpu") 462 elif policy == 'locate_cache': 463 self.num_axis = len(axes) 464 num_anchor = kwargs["num_anchor"] 465 self.anns = list(itertools.combinations(range(self.num_axis), num_anchor)) 466 self.entities = [AnnotateEntity(x) for x in self.anns] 467 else: # none, vec, unroll, try_vec, try_unroll, try_vec_unroll, ... 468 anns = policy.replace('try', 'none').split('_') 469 470 for ann in anns: 471 if ann not in ['none', 'unroll', 'vec']: 472 raise RuntimeError("Invalid policy: " + policy) 473 474 self.num_axis = len(axes) 475 self.anns = [anns] * self.num_axis 476 self._generate_space(0, [""] * self.num_axis) 477 478 def _generate_space(self, now, tmp_stack): 479 """Generate space by DFS""" 480 if now == self.num_axis: 481 # only vectorize inner most dimension 482 vec_ct = tmp_stack.count('vec') 483 if vec_ct in (0, 1): 484 self.entities.append(AnnotateEntity(list(tmp_stack))) 485 else: 486 for ann in self.anns[now]: 487 tmp_stack[now] = ann 488 self._generate_space(now + 1, tmp_stack) 489 490 @staticmethod 491 def get_num_output(axes, policy, **kwargs): 492 return len(axes) 493 494 def __repr__(self): 495 return "Annotate(policy=%s) len=%d" % (self.policy, len(self)) 496 497 498class AnnotateEntity(object): 499 """An annotation operation with detailed parameters that can apply to axes 500 501 Parameters 502 ---------- 503 anns: Array of string 504 The annotations of axes 505 """ 506 def __init__(self, anns): 507 self.anns = anns 508 509 def apply(self, sch, op, axes, axis_lens=None, 510 max_unroll=None, vec_size=None, cfg=None, source=None): 511 """Apply annotation to an array of axes 512 513 Parameters 514 ---------- 515 sch: tvm.schedule.Schedule 516 The tvm schedule 517 op: tvm.tensor.Operation 518 The stage to be applied 519 axes: Array of tvm.schedule.IterVar 520 axis to split 521 axis_lens: Array of int, optional 522 the length of axes 523 max_unroll: int, optional 524 maximum unroll step 525 vec_size: Array of int, optional 526 valid vector lanes for vectorization 527 cfg: ConfigEntity, optional 528 cfg for recording error 529 source: Array of Array tensor, optional 530 source tensor for attaching cache 531 532 Returns 533 ------- 534 axes : list of tvm.schedule.IterVar 535 The transformed axes 536 """ 537 if source is not None: # special case : attach cache_read/cache_write 538 for src, to in zip(source, self.anns): 539 for t in src: 540 sch[t].compute_at(sch[op], axes[to]) 541 else: # other cases 542 for i, ann in enumerate(self.anns): 543 if ann == 'none': 544 pass 545 elif ann == 'unroll': 546 if max_unroll and axis_lens[i] > max_unroll: 547 cfg.raise_error("Too large factor for unrolling") 548 sch[op].unroll(axes[i]) 549 elif ann == 'vec': 550 if vec_size and axis_lens[i] not in vec_size: 551 cfg.raise_error("Wrong size of lanes in vectorization") 552 sch[op].vectorize(axes[i]) 553 elif ann == 'blockIdx.x': 554 sch[op].bind(axes[i], thread_axis('blockIdx.x')) 555 elif ann == 'blockIdx.y': 556 sch[op].bind(axes[i], thread_axis('blockIdx.y')) 557 elif ann == 'blockIdx.z': 558 sch[op].bind(axes[i], thread_axis('blockIdx.z')) 559 elif ann == 'threadIdx.x': 560 sch[op].bind(axes[i], thread_axis('threadIdx.x')) 561 elif ann == 'threadIdx.y': 562 sch[op].bind(axes[i], thread_axis('threadIdx.y')) 563 elif ann == 'threadIdx.z': 564 sch[op].bind(axes[i], thread_axis('threadIdx.z')) 565 elif ann == 'vthread': 566 sch[op].bind(axes[i], thread_axis("vthread")) 567 elif ann == 'fuse': 568 assert i < len(axes) - 1 569 axes[i+1] = sch[op].fuse(axes[i], axes[i+1]) 570 else: 571 raise RuntimeError("Invalid annotation " + ann) 572 return axes 573 574 def __repr__(self): 575 return str(self.anns) 576 577 578class OtherOptionSpace(TransformSpace): 579 """The parameter space for general option""" 580 def __init__(self, axes, policy, **kwargs): 581 super(OtherOptionSpace, self).__init__() 582 583 candidate = kwargs["candidate"] 584 self.entities = [OtherOptionEntity(x) for x in candidate] 585 586 @staticmethod 587 def get_num_output(axes, policy, **kwargs): 588 return 0 589 590 def __repr__(self): 591 return "OtherOption(%s) len=%d" % (self.entities, len(self)) 592 593 594class OtherOptionEntity(object): 595 """The parameter entity for general option, with a detailed value""" 596 def __init__(self, val): 597 self.val = val 598 599 def __repr__(self): 600 return str(self.val) 601 602 603class ConfigSpace(object): 604 """The configuration space of a schedule. Pass it as config in template to 605 collect transformation space and build transform graph of axes 606 """ 607 def __init__(self): 608 # private dict to provide sugar 609 self.space_map = OrderedDict() # name -> space 610 self._collect = True 611 self._length = None 612 self._entity_map = OrderedDict() # name -> entity 613 self._constraints = [] 614 self.errors = [] 615 self.template_key = None 616 self.code_hash = None 617 self.flop = 0 618 self.is_fallback = False 619 620 @staticmethod 621 def axis(var): 622 """get a virtual axis (axis placeholder) 623 624 Parameters 625 ---------- 626 var: int or tvm.schedule.IterVar 627 If is int, return an axis whose length is the provided argument. 628 If is IterVar, return an axis whose length is extracted from the 629 IterVar's extent domain. 630 """ 631 return VirtualAxis(var) 632 633 reduce_axis = axis 634 635 def define_split(self, name, axis, policy='factors', **kwargs): 636 """Define a new tunable knob which splits an axis into a list of axes 637 638 Parameters 639 ---------- 640 name: str 641 name to index the entity of this space 642 axis: tvm.schedule.IterVar 643 axis to split 644 policy: str 645 name of policy. 646 If is 'factors', the tuner will try all divisible factors. 647 If is 'power2', the tuner will try power-of-two factors less or equal to the length. 648 If is 'verbose', the tuner will try all candidates in above two policies. 649 If is 'candidate', try given candidates. 650 kwargs: dict 651 extra arguments for policy 652 max_factor: int 653 the maximum split factor. 654 filter: function(int) -> bool 655 see examples below for how to use filter. 656 num_outputs: int 657 the total number of axis after split. 658 no_tail: bool 659 should we only include divisible numbers as split factors. 660 candidate: list 661 (policy=candidate) manual candidate list. 662 663 Examples 664 -------- 665 >>> # use custom candidates 666 >>> cfg.define_split('tile_x', x, policy='candidate', candidate=[[1, 4, 4], [4, 1, 4]]) 667 668 >>> # use a filter that only accepts the split scheme whose inner most tile is less then 4 669 >>> cfg.define_split('tile_y', y, policy='factors', filter=lambda x: x.size[-1] <= 4) 670 """ 671 axes = [axis] 672 return self._add_new_transform(SplitSpace, name, axes, policy, **kwargs) 673 674 def define_reorder(self, name, axes, policy, **kwargs): 675 """Define a new tunable knob which reorders a list of axes 676 677 Parameters 678 ---------- 679 name: str 680 name to index the entity of this space 681 axes: Array of tvm.schedule.IterVar 682 axes to reorder 683 policy: str 684 name of policy 685 If is 'identity', do an identity permutation. 686 If is 'all', try all permutations. 687 If is 'interval_all', try all permutations of an interval of axes. 688 If is 'candidate', try listed candidate. 689 If is 'interleave', interleave chains of spatial axes and chains of reduction axes. 690 kwargs: dict 691 extra arguments for policy 692 """ 693 return self._add_new_transform(ReorderSpace, name, axes, policy, **kwargs) 694 695 def define_annotate(self, name, axes, policy, **kwargs): 696 """Define a new tunable knob which annotates a list of axes 697 698 Parameters 699 ---------- 700 name: str 701 name to index the entity of this space 702 axes: Array of tvm.schedule.IterVar 703 axes to annotate 704 policy: str 705 name of policy 706 If is 'unroll', unroll the axes. 707 If is 'try_unroll', try to unroll the axes. 708 If is 'try_unroll_vec', try to unroll or vectorize the axes. 709 If is 'bind_gpu', bind the first few axes to gpu threads. 710 If is 'locate_cache', choose n axes to attach shared/local cache. 711 kwargs: dict 712 extra arguments for policy 713 """ 714 return self._add_new_transform(AnnotateSpace, name, axes, policy, **kwargs) 715 716 def define_knob(self, name, candidate): 717 """Define a tunable knob with a list of candidates 718 719 Parameters 720 ---------- 721 name: str 722 name key of that option 723 candidate: list 724 list of candidates 725 """ 726 return self._add_new_transform(OtherOptionSpace, name, [], None, candidate=candidate) 727 728 def add_flop(self, flop): 729 """Add float operation statistics for this tuning task 730 731 Parameters 732 --------- 733 flop: int or float 734 number of float operations 735 """ 736 self.flop += flop 737 738 def raise_error(self, msg): 739 """register error in config 740 Using this to actively detect error when scheudling. 741 Otherwise these error will occur during runtime, which 742 will cost more time. 743 744 Parameters 745 ---------- 746 msg: str 747 """ 748 self.errors.append(msg) 749 750 def valid(self): 751 """Check whether the config meets all the constraints 752 Note: This check should be called after instantiation of task, 753 because the ConfigEntity/ConfigSpace collects errors during instantiation 754 755 Returns 756 ------- 757 valid: bool 758 whether the config meets all the constraints 759 """ 760 return not bool(self.errors) 761 762 def _add_new_transform(self, space_class, name, axes, policy, **kwargs): 763 """Add a new transform space in template""" 764 if self._collect: 765 # convert schedule axis to space definition axis 766 axes = [x if isinstance(x, (VirtualAxis, Axis)) else self.axis(x) for x in axes] 767 768 # add subspace (knob) 769 space = space_class(axes, policy, **kwargs) 770 self.space_map[name] = space 771 self._entity_map[name] = space[0] 772 return [Axis(space, i) for i in range(space.num_output)] 773 return [Axis(None, i) for i in range(space_class.get_num_output(axes, policy, **kwargs))] 774 775 def __len__(self): 776 if self._length is None: 777 self._length = int(np.prod([len(x) for x in self.space_map.values()])) 778 return self._length 779 780 def get(self, index): 781 """Get a config entity with detailed parameters from this space 782 783 Parameters 784 ---------- 785 index: int 786 index in the space 787 """ 788 entities = OrderedDict() 789 t = index 790 for name, space in self.space_map.items(): 791 entities[name] = space[t % len(space)] 792 t //= len(space) 793 ret = ConfigEntity(index, self.code_hash, self.template_key, entities, self._constraints) 794 return ret 795 796 def __iter__(self): 797 return self._entity_map.__iter__() 798 799 def __getitem__(self, name): 800 """get the transform entity(knob) of this entity by name 801 do not use this to get a ConfigEntity of this space (should use ConfigSpace.get instead) 802 803 Parameters 804 ---------- 805 name: str 806 name of the transform 807 """ 808 return self._entity_map[name] 809 810 def __repr__(self): 811 res = "ConfigSpace (len=%d, space_map=\n" % len(self) 812 for i, (name, space) in enumerate(self.space_map.items()): 813 res += " %2d %s: %s\n" % (i, name, space) 814 return res + ")" 815 816 817_ann_to_number = { 818 'none': 0, 'vec': 1, 'unroll': 2, 819 'blockIdx.x': 3, 'blockIdx.y': 4, 'blockIdx.z': 5, 820 'threadIdx.x': 6, 'threadIdx.y': 7, 'threadIdx.z': 8, 821 'vthread': 9, 'fuse': 10 822} 823 824class ConfigEntity(ConfigSpace): 825 """A configuration with detailed parameters 826 827 Parameters 828 ---------- 829 index: int 830 index of this config in space 831 code_hash: str 832 hash of schedule code 833 template_key : str 834 The specific template key 835 entity_map: dict 836 map name to transform entity 837 constraints : list 838 List of constraints 839 """ 840 def __init__(self, index, code_hash, template_key, entity_map, constraints): 841 super(ConfigEntity, self).__init__() 842 self.index = index 843 self.template_key = template_key 844 self._collect = False 845 self._entity_map = entity_map 846 self._space_map = None 847 self._constraints = constraints 848 self.code_hash = code_hash 849 850 def get_flatten_feature(self): 851 """ flatten entities to a numerical one-dimensional feature vector 852 853 Returns 854 ------- 855 fea: np.array 856 one dimensional float32 array 857 """ 858 fea = [] 859 for _, v in self._entity_map.items(): 860 if isinstance(v, SplitEntity): 861 fea.extend(v.size) 862 elif isinstance(v, ReorderEntity): 863 # use a naive way: directly copy the permutation 864 fea.extend(v.perm) 865 elif isinstance(v, AnnotateEntity): 866 # one-hot encoding 867 for ann in v.anns: 868 tmp = [0] * len(_ann_to_number) 869 tmp[_ann_to_number[ann]] = 1 870 fea.extend(tmp) 871 elif isinstance(v, OtherOptionEntity): 872 fea.append(v.val) 873 return np.array(fea, dtype=np.float32) 874 875 def get_other_option(self): 876 """ 877 Returns 878 ------- 879 other_option: dict 880 other tunable parameters (tunable parameters defined by `cfg.define_knob`) 881 """ 882 return {x: x.val for x in self._entity_map.values() if isinstance(x, OtherOptionEntity)} 883 884 def to_json_dict(self): 885 """convert to a json serializable dictionary 886 887 Return 888 ------ 889 json_dict: dict 890 a json serializable dictionary 891 """ 892 ret = {} 893 ret['i'] = int(self.index) 894 ret['t'] = self.template_key 895 ret['c'] = self.code_hash 896 entity_map = [] 897 for k, v in self._entity_map.items(): 898 if isinstance(v, SplitEntity): 899 entity_map.append((k, 'sp', v.size)) 900 elif isinstance(v, ReorderEntity): 901 entity_map.append((k, 're', v.perm)) 902 elif isinstance(v, AnnotateEntity): 903 entity_map.append((k, 'an', v.anns)) 904 elif isinstance(v, OtherOptionEntity): 905 entity_map.append((k, 'ot', v.val)) 906 else: 907 raise RuntimeError("Invalid entity instance: " + v) 908 ret['e'] = entity_map 909 return ret 910 911 @staticmethod 912 def from_json_dict(json_dict): 913 """Build a ConfigEntity from json serializable dictionary 914 915 Parameters 916 ---------- 917 json_dict: dict 918 Json serializable dictionary. This should be the return value 919 of :any:`to_json_dict`. 920 921 Returns 922 ------- 923 config: ConfigEntity 924 The corresponding config object 925 926 """ 927 index = json_dict["i"] 928 code_hash = json_dict["c"] 929 template_key = json_dict["t"] 930 constraints = [] 931 entity_map = OrderedDict() 932 933 for item in json_dict["e"]: 934 key, knob_type, knob_args = item 935 if knob_type == 'sp': 936 entity = SplitEntity(knob_args) 937 elif knob_type == 're': 938 entity = ReorderEntity(knob_args) 939 elif knob_type == 'an': 940 entity = AnnotateEntity(knob_args) 941 elif knob_type == 'ot': 942 entity = OtherOptionEntity(knob_args) 943 else: 944 raise RuntimeError("Invalid config knob type: " + knob_type) 945 entity_map[str(key)] = entity 946 947 return ConfigEntity(index, code_hash, template_key, entity_map, constraints) 948 949 def __repr__(self): 950 return "%s,%s,%s,%d" % (str(self._entity_map)[12:-1], self.template_key, 951 self.code_hash, self.index) 952 953 954class FallbackConfigEntity(ConfigSpace): 955 """The config entity created to support fallback""" 956 957 def __init__(self): 958 super(FallbackConfigEntity, self).__init__() 959 self.is_fallback = True 960 961 def fallback_split(self, name, constraints): 962 """Fallback a split knob 963 964 Parameters 965 ---------- 966 name: str 967 name of the knob 968 constraints: List of int 969 The maximum tile size for every dimension. Value `-1` means no constraint. 970 971 Examples 972 -------- 973 If you use cfg.define_split('tile_0', 128, num_outputs=3), 974 Then cfg.fallback_split('tile_0', [-1, 8, 4]) will give you cfg['tile_0'].size = [4, 8, 4] 975 976 If you use cfg.define_split('tile_0', 49, num_outputs=3), 977 Then cfg.fallback_split('tile_0', [-1, 8, 4]) will give you cfg['tile_0'].size = [7, 7, 1] 978 """ 979 space = self.space_map[name] 980 assert isinstance(space, SplitSpace) 981 assert len(constraints) == space.num_output 982 983 # '-1' means no constraint 984 constraints = [x if x != -1 else 1e10 for x in constraints] 985 986 entity = self._entity_map[name] 987 now = space.product 988 989 for i in reversed(range(space.num_output)): 990 factors = get_factors(now) 991 992 find = len(factors) - 1 993 for j, f in enumerate(factors): 994 if f > constraints[i]: 995 find = j - 1 996 break 997 998 if find >= 0: 999 entity.size[i] = factors[find] 1000 now //= factors[find] 1001 else: 1002 raise RuntimeError("Cannot find feasible fallback split entity for node: " + name) 1003 1004 def fallback_with_reference_log(self, ref_log): 1005 """A data driven fallback mechanism. 1006 We use tuned parameters from TopHub as reference data. 1007 For an unseen shape, we find the most similar tuned one from TopHub and 1008 mimic its parameters. 1009 Note that we are not matching by workload (e.g., input size, kernel size), 1010 but instead matching by configuration space. The idea is that if two workloads have 1011 similar configuration space, their optimal configurations are also likely to be similar. 1012 1013 Parameters 1014 ---------- 1015 ref_log: List of (MeasureInput, MeasureResult) 1016 The reference log 1017 """ 1018 knob_names = [x for x in self.space_map.keys() if 1019 isinstance(self.space_map[x], SplitSpace)] 1020 1021 # find best match config in reference data by matching tiling factors 1022 factor_list = [] 1023 for knob_name in knob_names: 1024 factor_list.append(get_factors(self.space_map[knob_name].product)) 1025 1026 best_match_cfg = None 1027 best_match_score = 0 1028 for inp, _ in ref_log: 1029 match_score = 0 1030 for i, knob_name in enumerate(knob_names): 1031 factors = get_factors(int(np.prod(inp.config[knob_name].size))) 1032 match_score += (float(len(set(factor_list[i]).intersection(factors))) / 1033 len(factor_list[i])) 1034 1035 if match_score > best_match_score: 1036 best_match_score, best_match_cfg = match_score, inp.config 1037 1038 if best_match_cfg is None: 1039 return 1040 1041 # mimic its tiling strategy 1042 for knob_name in knob_names: 1043 constraint = list(best_match_cfg[knob_name].size) 1044 constraint[0] = -1 1045 self.fallback_split(knob_name, constraint) 1046 1047 # copy other knobs 1048 for knob_name in self.space_map.keys(): 1049 if not isinstance(self.space_map[knob_name], SplitSpace): 1050 self._entity_map[knob_name] = best_match_cfg[knob_name] 1051 1052 def __setitem__(self, name, entity): 1053 """set the entity(knob) of by name 1054 1055 Parameters 1056 ---------- 1057 name: str 1058 name of the entity 1059 entity: SplitEntity, ReorderEntity, AnnotateEntity, OtherOptionEntity 1060 value of the entity 1061 """ 1062 self._entity_map[name] = entity 1063 1064 def __repr__(self): 1065 return "%s,%s,%s" % (str(self._entity_map)[12:-1], self.template_key, self.code_hash) 1066