# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-few-public-methods,invalid-name,unused-argument,arguments-differ
# pylint: disable=consider-using-enumerate,too-many-lines
19"""
20Template configuration space.
21
22Each template function can be parametrized by a ConfigSpace.
23The space is declared when we invoke the template function with ConfigSpace.
24During evaluation, we pass in a ConfigEntity, which contains a specific
25entity in the space. This entity contains deterministic parameters.
26"""
from __future__ import absolute_import as _abs

import itertools
import functools
import math
from collections import namedtuple, OrderedDict
import numpy as np

from tvm import schedule, thread_axis
from tvm.autotvm.util import get_const_int

Axis = namedtuple('Axis', ['space', 'index'])

try:
    _long = long
except NameError:
    _long = int


class InstantiationError(ValueError):
    """Error raised when a template cannot be instantiated with a given config.
    It is raised actively by cfg.raise_error,
    e.g. for too much unrolling or too many threads in a block.
    """


class TransformSpace(object):
    """Base class for transform spaces.
    A TransformSpace is a node in the computation graph of axes.

    Note
    ----
    We can regard our schedule code as a transformation graph of axes.
    Starting from the raw axes in the definition of tvm.compute, we can transform these axes
    by operators such as 'split', 'reorder' and 'annotate'.
    Each operator has some tunable parameters (e.g. the split factor).
    The tuning process is simply finding good parameters for these operators.

    So all combinations of the parameters of these operators form our search space.

    Naming convention:
    We call the set of all possible values XXXSpace (XXX can be Split, Reorder, Config ...).
    We call a specific entity in a space an XXXEntity.
    """
    def __init__(self):
        self.ins = []
        self.num_output = 0
        self.entities = []

    def __len__(self):
        return len(self.entities)

    def __getitem__(self, index):
        """Get an entity of the space by index

        Parameters
        ----------
        index: int

        Returns
        -------
        transform entity
        """
        return self.entities[index]

    @staticmethod
    def get_num_output():
        """get number of output axes after this transform

        Returns
        -------
        n: int
            number of output axes
        """
        return 0


class VirtualAxis(TransformSpace):
    """Axis placeholder in template

    Parameters
    ----------
    var: int or tvm.schedule.IterVar
        If an int, return a virtual axis whose length is the provided argument.
        If an IterVar, return a virtual axis whose length is extracted from
        the IterVar's extent domain.
    name: str
    """
    name_ct = 0

    def __init__(self, var, name=None):
        super(VirtualAxis, self).__init__()
        self.num_output = 1

        if name is None:
            name = 'axis_%d' % VirtualAxis.name_ct
            VirtualAxis.name_ct += 1

        self.name = name
        if isinstance(var, (int, _long)):
            self.length = var
        elif isinstance(var, schedule.IterVar):
            self.name = var.var.name
            if var.dom is None:
                self.length = -1
            else:
                self.length = get_const_int(var.dom.extent)
        elif isinstance(var, VirtualAxis):
            self.length = var.length
        else:
            raise RuntimeError("Invalid type of axis: " + str(type(var)))

    @staticmethod
    def get_num_output(var, name=None):
        return 1

    def __repr__(self):
        return "vaxis(%s)" % self.name


def get_factors(n):
    """return all factors of an integer

    Parameters
    ----------
    n: int
        integer to factorize

    Returns
    -------
    factors: list
        List of all factors
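
    Examples
    --------
    >>> get_factors(12)
    [1, 2, 3, 4, 6, 12]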
    """
    step = 2 if n % 2 else 1
    ret = list(set(
        functools.reduce(
            list.__add__, ([i, n//i] for i in range(1, int(math.sqrt(n)) + 1, step)
                           if n % i == 0))))
    ret.sort()
    return ret

def get_pow2s(n):
    """return all power-of-two numbers that are less than or equal to the integer

    Parameters
    ----------
    n: int
        integer for reference

    Returns
    -------
    factors: list
        List of all power-of-two numbers
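
    Examples
    --------
    >>> get_pow2s(20)
    [1, 2, 4, 8, 16]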
    """
    return [2**x for x in range(math.floor(math.log2(n)) + 1)]

class SplitSpace(TransformSpace):
    """Split an axis multiple times"""
    def __init__(self, axes, policy, **kwargs):
        super(SplitSpace, self).__init__()
        axis = axes[0]

        self.policy = policy
        self.entities = []

        max_factor = kwargs.get("max_factor", 1 << 31)
        fil = kwargs.get("filter", lambda x: True)
        self.product = axis.length
        self.num_output = kwargs.get("num_outputs", 0)
        assert self.num_output > 0

        if policy == 'candidate':
            for size in kwargs["candidate"]:
                assert len(size) == self.num_output
                self.entities.append(SplitEntity(size))
        else:
            if policy == 'verbose':
                # Include factors and power-of-twos. May generate tails.
                divisibles = get_factors(self.product)
                pow2s = get_pow2s(self.product)
                factors = [x for x in list(set(divisibles) | set(pow2s)) if x <= max_factor]
            elif policy == 'factors':
                # Include divisible factors. Guarantees no tails.
                factors = [x for x in get_factors(self.product) if x <= max_factor]
            elif policy == 'power2':
                # Include power-of-two numbers no larger than the axis length. May generate tails.
                factors = [x for x in get_pow2s(self.product) if x <= max_factor]
            else:
                raise RuntimeError("Invalid policy: %s" % policy)

            # Whether to enforce that the product of all split factors equals the axis length
            no_tail = kwargs.get("no_tail", policy == 'factors')

            # Generate split entities by enumerating candidate factors.
            self.factors = factors
            self._generate_space(0, [None] * (self.num_output - 1), enforce_no_tail=no_tail)

        self.entities = list(filter(fil, self.entities))

    def _generate_space(self, now, tmp_stack, enforce_no_tail=False):
        """Generate space by DFS"""
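        # tmp_stack holds the explicit factors chosen so far; the leading -1 in
        # each generated entity is a placeholder for the outermost factor, which
        # is implied by the axis length once the split is applied.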
        if now == self.num_output - 1:
            prod = functools.reduce(lambda x, y: x * y, tmp_stack)
            if prod > self.product:
                return
            if self.product % prod == 0 or (not enforce_no_tail and prod < self.product):
                self.entities.append(SplitEntity([-1] + tmp_stack[::-1]))
        else:
            for factor in self.factors:
                tmp_stack[now] = factor
                self._generate_space(now + 1, tmp_stack, enforce_no_tail)

    @staticmethod
    def get_num_output(axes, policy, **kwargs):
        return kwargs["num_outputs"]

    def __repr__(self):
        return ("Split(policy=%s, product=%d, num_outputs=%d) len=%d" %
                (self.policy, self.product, self.num_output, len(self)))


class SplitEntity(object):
    """
    A split operation with detailed parameters
    that can be applied to an axis

    Parameters
    ----------
    size: Array of int
        the size of every axis after the split
        e.g. for an axis of extent 128 split into 3 axes, a possible
             size is [4, 4, 8] (4x4x8 = 128)
    """
    def __init__(self, size):
        self.size = size

    def apply(self, sch, op, axis):
        """Apply split to an axis

        Parameters
        ----------
        sch: tvm.schedule.Schedule
            The tvm schedule
        op: tvm.tensor.Operation
            The operation whose stage the split is applied to
        axis: tvm.schedule.IterVar
            axis to split

        Returns
        -------
        axes : list of tvm.schedule.IterVar
            The transformed axes.
        """
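        # Split from the outer side inward: at step i, the product of the
        # remaining sizes size[i:] becomes the extent of the inner part; the
        # outer parts are collected in order and the last inner axis is
        # appended as the innermost axis.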
        ret = []
        for i in range(1, len(self.size)):
            ax0, ax1 = sch[op].split(axis, int(np.prod(self.size[i:])))
            ret.append(ax0)
            axis = ax1
        return ret + [axis]

    def __repr__(self):
        return str(self.size)


class ReorderSpace(TransformSpace):
    """The parameter space for ordering an array of axes"""
    def __init__(self, axes, policy, **kwargs):
        super(ReorderSpace, self).__init__()
        self.ins = axes
        self.policy = policy
        self.num_output = len(axes)

        if policy == 'identity':
            self.entities = [ReorderEntity(range(len(axes)))]
        elif policy == 'all':
            self.entities = [
                ReorderEntity(x) for x in itertools.permutations(range(len(axes)))]
        elif policy == 'interval_all':
            begin, end = kwargs['interval']
            sub_space = list(itertools.permutations(range(begin, end)))
            prefix, suffix = tuple(range(begin)), tuple(range(end, len(axes)))
            self.entities = [ReorderEntity(prefix + x + suffix) for x in sub_space]
        elif policy == 'candidate':
            candidate = kwargs["candidate"]
            for can in candidate:
                perm = [axes.index(x) for x in can]
                self.entities.append(ReorderEntity(perm))
        elif policy == 'interleave':
            spatial, reduce = kwargs['spatial'], kwargs['reduce']

            spatial = [[axes.index(x) for x in ch] for ch in spatial]
            reduce = [[axes.index(x) for x in ch] for ch in reduce]

            outer_merged = self._merge_chain([x[:-1] for x in spatial])
            inner_merged = self._merge_chain([x[-1:] for x in spatial] + reduce)

            for o in outer_merged:
                for i in inner_merged:
                    self.entities.append(ReorderEntity(o + i))
        elif policy == 'interleave_cuda':
            spatial, reduce = kwargs['spatial'], kwargs['reduce']

            spatial = [[axes.index(x) for x in ch] for ch in spatial]
            reduce = [[axes.index(x) for x in ch] for ch in reduce]

            outer_merged = self._merge_chain([x[:-1] for x in spatial])
            reduce_merged = self._merge_chain(reduce)
            inner_merged = [x[-1] for x in spatial]

            for o in outer_merged:
                for r in reduce_merged:
                    self.entities.append(ReorderEntity(o + r + inner_merged))
        else:
            raise RuntimeError("Invalid policy: " + policy)

    @staticmethod
    def get_num_output(axes, policy, **kwargs):
        return len(axes)

    def __repr__(self):
        return "Reorder(policy=%s) len=%d" % (self.policy, len(self))

    def _merge_chain(self, chains):
        """generate all combinations of merging several chains"""
        merged = []
        tmp_pt = [0] * len(chains)
        tmp_stack = []

        size = np.sum([len(x) for x in chains])
        self._merge_dfs(chains, size, tmp_pt, tmp_stack, merged)
        return merged

    def _merge_dfs(self, chains, size, tmp_pt, tmp_stack, merged):
        if np.sum(tmp_pt) == size:
            merged.append(list(tmp_stack))
            return

        for i in range(len(chains)):
            # use i == np.argmax(....) here to take spatial order into consideration
            # if we don't want to consider spatial order, we can use tmp_pt[i] == np.max(....)
            if (tmp_pt[i] < len(chains[i]) and
                    (i == np.argmax([len(chains[x]) - tmp_pt[x] for x in range(len(chains))]))):
                tmp_stack.append(chains[i][tmp_pt[i]])
                tmp_pt[i] += 1
                self._merge_dfs(chains, size, tmp_pt, tmp_stack, merged)
                tmp_pt[i] -= 1
                tmp_stack.pop()


class ReorderEntity(object):
    """A reorder operation with detailed parameters that can be applied to axes

    Parameters
    ----------
    perm: Array of int
        define the permutation
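        e.g. perm = [2, 0, 1] produces the order [axes[2], axes[0], axes[1]]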
    """
    def __init__(self, perm):
        self.perm = perm

    def apply(self, sch, op, axes):
        """Apply reorder to an array of axes

        Parameters
        ----------
        sch: tvm.schedule.Schedule
            The tvm schedule
        op: tvm.tensor.Operation
            The operation whose stage the reorder is applied to
        axes: Array of tvm.schedule.IterVar
            axes to reorder

        Returns
        -------
        axes : list of tvm.schedule.IterVar
            The transformed axes.
        """
        if len(axes) == len(self.perm):
            new_order = [axes[i] for i in self.perm]
        else:
            new_order = [axes[i] for i in self.perm if i < len(axes)]
        sch[op].reorder(*new_order)
        return new_order

    def __repr__(self):
        return str(self.perm)


class AnnotateSpace(TransformSpace):
    """The parameter space for annotating an array of axes"""
    def __init__(self, axes, policy, **kwargs):
        super(AnnotateSpace, self).__init__()

        self.ins = axes
        self.policy = policy
        self.num_output = len(axes)

        if policy == 'bind_gpu':
            self.num_axis = len(axes)
            if self.num_axis >= 6:
                self.entities.append(AnnotateEntity(
                    ['fuse'] * (self.num_axis - 6) +
                    ['blockIdx.z', 'blockIdx.y', 'blockIdx.x',
                     'threadIdx.z', 'threadIdx.y', 'threadIdx.x']))
            elif self.num_axis >= 4:
                self.entities.append(AnnotateEntity(
                    ['fuse'] * (self.num_axis - 4) +
                    ['blockIdx.y', 'blockIdx.x',
                     'threadIdx.y', 'threadIdx.x']))
            elif self.num_axis >= 2:
                self.entities.append(AnnotateEntity(
                    ['fuse'] * (self.num_axis - 2) +
                    ['blockIdx.x', 'threadIdx.x']))
            else:
                raise RuntimeError("Unhandled case in bind_gpu")
        elif policy == 'bind_gpu_virtual':
            self.num_axis = len(axes)
            if self.num_axis >= 9:
                self.entities.append(AnnotateEntity(
                    ['fuse'] * (self.num_axis - 9) +
                    ['blockIdx.z', 'blockIdx.y', 'blockIdx.x',
                     'vthread', 'vthread', 'vthread',
                     'threadIdx.z', 'threadIdx.y', 'threadIdx.x']))
            elif self.num_axis >= 6:
                self.entities.append(AnnotateEntity(
                    ['fuse'] * (self.num_axis - 6) +
                    ['blockIdx.y', 'blockIdx.x',
                     'vthread', 'vthread',
                     'threadIdx.y', 'threadIdx.x']))
            elif self.num_axis >= 3:
                self.entities.append(AnnotateEntity(
                    ['fuse'] * (self.num_axis - 3) +
                    ['blockIdx.x', 'vthread', 'threadIdx.x']))
            else:
                raise RuntimeError("Unhandled case in bind_gpu_virtual")
        elif policy == 'locate_cache':
            self.num_axis = len(axes)
            num_anchor = kwargs["num_anchor"]
            self.anns = list(itertools.combinations(range(self.num_axis), num_anchor))
            self.entities = [AnnotateEntity(x) for x in self.anns]
        else:  # none, vec, unroll, try_vec, try_unroll, try_vec_unroll, ...
            anns = policy.replace('try', 'none').split('_')

            for ann in anns:
                if ann not in ['none', 'unroll', 'vec']:
                    raise RuntimeError("Invalid policy: " + policy)

            self.num_axis = len(axes)
            self.anns = [anns] * self.num_axis
            self._generate_space(0, [""] * self.num_axis)

    def _generate_space(self, now, tmp_stack):
        """Generate space by DFS"""
        if now == self.num_axis:
            # keep at most one 'vec' annotation (only one dimension is vectorized)
            vec_ct = tmp_stack.count('vec')
            if vec_ct in (0, 1):
                self.entities.append(AnnotateEntity(list(tmp_stack)))
        else:
            for ann in self.anns[now]:
                tmp_stack[now] = ann
                self._generate_space(now + 1, tmp_stack)

    @staticmethod
    def get_num_output(axes, policy, **kwargs):
        return len(axes)

    def __repr__(self):
        return "Annotate(policy=%s) len=%d" % (self.policy, len(self))


class AnnotateEntity(object):
    """An annotation operation with detailed parameters that can be applied to axes

    Parameters
    ----------
    anns: Array of string
        The annotations of axes
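        e.g. ['none', 'unroll', 'vec'] or ['blockIdx.x', 'threadIdx.x']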
    """
    def __init__(self, anns):
        self.anns = anns

    def apply(self, sch, op, axes, axis_lens=None,
              max_unroll=None, vec_size=None, cfg=None, source=None):
        """Apply annotation to an array of axes

        Parameters
        ----------
        sch: tvm.schedule.Schedule
            The tvm schedule
        op: tvm.tensor.Operation
            The operation whose stage the annotation is applied to
        axes: Array of tvm.schedule.IterVar
            axes to annotate
        axis_lens: Array of int, optional
            the length of axes
        max_unroll: int, optional
            maximum unroll step
        vec_size: Array of int, optional
            valid vector lanes for vectorization
        cfg: ConfigEntity, optional
            cfg for recording errors
        source: Array of Array of Tensor, optional
            source tensors for attaching cache

        Returns
        -------
        axes : list of tvm.schedule.IterVar
            The transformed axes
        """
        if source is not None:  # special case: attach cache_read/cache_write
            for src, to in zip(source, self.anns):
                for t in src:
                    sch[t].compute_at(sch[op], axes[to])
        else:  # other cases
            for i, ann in enumerate(self.anns):
                if ann == 'none':
                    pass
                elif ann == 'unroll':
                    if max_unroll and axis_lens[i] > max_unroll:
                        cfg.raise_error("Too large factor for unrolling")
                    sch[op].unroll(axes[i])
                elif ann == 'vec':
                    if vec_size and axis_lens[i] not in vec_size:
                        cfg.raise_error("Wrong size of lanes in vectorization")
                    sch[op].vectorize(axes[i])
                elif ann == 'blockIdx.x':
                    sch[op].bind(axes[i], thread_axis('blockIdx.x'))
                elif ann == 'blockIdx.y':
                    sch[op].bind(axes[i], thread_axis('blockIdx.y'))
                elif ann == 'blockIdx.z':
                    sch[op].bind(axes[i], thread_axis('blockIdx.z'))
                elif ann == 'threadIdx.x':
                    sch[op].bind(axes[i], thread_axis('threadIdx.x'))
                elif ann == 'threadIdx.y':
                    sch[op].bind(axes[i], thread_axis('threadIdx.y'))
                elif ann == 'threadIdx.z':
                    sch[op].bind(axes[i], thread_axis('threadIdx.z'))
                elif ann == 'vthread':
                    sch[op].bind(axes[i], thread_axis("vthread"))
                elif ann == 'fuse':
                    assert i < len(axes) - 1
                    axes[i+1] = sch[op].fuse(axes[i], axes[i+1])
                else:
                    raise RuntimeError("Invalid annotation " + ann)
        return axes

    def __repr__(self):
        return str(self.anns)


class OtherOptionSpace(TransformSpace):
    """The parameter space for general option"""
    def __init__(self, axes, policy, **kwargs):
        super(OtherOptionSpace, self).__init__()

        candidate = kwargs["candidate"]
        self.entities = [OtherOptionEntity(x) for x in candidate]

    @staticmethod
    def get_num_output(axes, policy, **kwargs):
        return 0

    def __repr__(self):
        return "OtherOption(%s) len=%d" % (self.entities, len(self))


class OtherOptionEntity(object):
    """The parameter entity for general option, with a detailed value"""
    def __init__(self, val):
        self.val = val

    def __repr__(self):
        return str(self.val)


class ConfigSpace(object):
    """The configuration space of a schedule. Pass it as the config argument of a
    template to collect the transformation space and build the transform graph of axes.
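
    Examples
    --------
    A minimal sketch of how a template might declare and apply knobs; the knob
    name 'tile_x', the schedule ``s``, the stage ``C`` and the axis ``x`` are
    illustrative placeholders:

    >>> # declare a tunable split knob while the space is being collected
    >>> cfg.define_split('tile_x', x, num_outputs=2)
    >>> # later, apply the concrete choice of the given ConfigEntity
    >>> xo, xi = cfg['tile_x'].apply(s, C, x)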
    """
    def __init__(self):
        # private dict to provide sugar
        self.space_map = OrderedDict()    # name -> space
        self._collect = True
        self._length = None
        self._entity_map = OrderedDict()  # name -> entity
        self._constraints = []
        self.errors = []
        self.template_key = None
        self.code_hash = None
        self.flop = 0
        self.is_fallback = False

    @staticmethod
    def axis(var):
        """get a virtual axis (axis placeholder)

        Parameters
        ----------
        var: int or tvm.schedule.IterVar
            If an int, return an axis whose length is the provided argument.
            If an IterVar, return an axis whose length is extracted from the
            IterVar's extent domain.
        """
        return VirtualAxis(var)

    reduce_axis = axis

    def define_split(self, name, axis, policy='factors', **kwargs):
        """Define a new tunable knob which splits an axis into a list of axes

        Parameters
        ----------
        name: str
            name to index the entity of this space
        axis: tvm.schedule.IterVar
            axis to split
        policy: str
            name of the policy.
            If it is 'factors', the tuner will try all divisible factors.
            If it is 'power2', the tuner will try power-of-two factors less than
            or equal to the length.
            If it is 'verbose', the tuner will try all candidates from the above two policies.
            If it is 'candidate', the tuner will try the given candidates.
        kwargs: dict
            extra arguments for the policy
            max_factor: int
                the maximum split factor.
            filter: function(SplitEntity) -> bool
                see the examples below for how to use filter.
            num_outputs: int
                the total number of axes after the split.
            no_tail: bool
                whether to only generate split schemes with no tail
                (the product of the factors equals the axis length).
            candidate: list
                (policy=candidate) manual candidate list.

        Examples
        --------
        >>> # use custom candidates
        >>> cfg.define_split('tile_x', x, policy='candidate', num_outputs=3,
        ...                  candidate=[[1, 4, 4], [4, 1, 4]])

        >>> # use a filter that only accepts split schemes whose innermost tile is at most 4
        >>> cfg.define_split('tile_y', y, policy='factors', filter=lambda x: x.size[-1] <= 4)
        """
        axes = [axis]
        return self._add_new_transform(SplitSpace, name, axes, policy, **kwargs)

    def define_reorder(self, name, axes, policy, **kwargs):
        """Define a new tunable knob which reorders a list of axes

        Parameters
        ----------
        name: str
            name to index the entity of this space
        axes: Array of tvm.schedule.IterVar
            axes to reorder
        policy: str
            name of the policy
            If it is 'identity', do an identity permutation.
            If it is 'all', try all permutations.
            If it is 'interval_all', try all permutations of an interval of axes.
            If it is 'candidate', try the listed candidates.
            If it is 'interleave', interleave chains of spatial axes and chains of reduction axes.
        kwargs: dict
            extra arguments for the policy
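
        Examples
        --------
        >>> # a hypothetical knob that tries every permutation of three placeholder axes
        >>> cfg.define_reorder('reorder_0', [a, b, c], policy='all')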
        """
        return self._add_new_transform(ReorderSpace, name, axes, policy, **kwargs)

    def define_annotate(self, name, axes, policy, **kwargs):
        """Define a new tunable knob which annotates a list of axes

        Parameters
        ----------
        name: str
            name to index the entity of this space
        axes: Array of tvm.schedule.IterVar
            axes to annotate
        policy: str
            name of the policy
            If it is 'unroll', unroll the axes.
            If it is 'try_unroll', try to unroll the axes.
            If it is 'try_unroll_vec', try to unroll or vectorize the axes.
            If it is 'bind_gpu', bind the first few axes to GPU threads.
            If it is 'locate_cache', choose n axes to attach shared/local cache.
        kwargs: dict
            extra arguments for the policy
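
        Examples
        --------
        >>> # a hypothetical knob that tries to unroll or vectorize two placeholder inner axes
        >>> cfg.define_annotate('ann_inner', [yi, xi], policy='try_unroll_vec')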
        """
        return self._add_new_transform(AnnotateSpace, name, axes, policy, **kwargs)

    def define_knob(self, name, candidate):
        """Define a tunable knob with a list of candidates

        Parameters
        ----------
        name: str
            name key of this option
        candidate: list
            list of candidates
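
        Examples
        --------
        >>> # a hypothetical knob; the name and candidate values are illustrative
        >>> cfg.define_knob('auto_unroll_max_step', [0, 512, 1500])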
        """
        return self._add_new_transform(OtherOptionSpace, name, [], None, candidate=candidate)

    def add_flop(self, flop):
        """Add float operation statistics for this tuning task

        Parameters
        ----------
        flop: int or float
            number of float operations
        """
        self.flop += flop

    def raise_error(self, msg):
        """Register an error in the config.
        Use this to actively detect errors when scheduling.
        Otherwise these errors will occur at runtime, which
        costs more time.

        Parameters
        ----------
        msg: str
        """
        self.errors.append(msg)

    def valid(self):
        """Check whether the config meets all the constraints
        Note: This check should be called after instantiation of the task,
              because the ConfigEntity/ConfigSpace collects errors during instantiation

        Returns
        -------
        valid: bool
            whether the config meets all the constraints
        """
        return not bool(self.errors)

    def _add_new_transform(self, space_class, name, axes, policy, **kwargs):
        """Add a new transform space in template"""
        if self._collect:
            # convert schedule axis to space definition axis
            axes = [x if isinstance(x, (VirtualAxis, Axis)) else self.axis(x) for x in axes]

            # add subspace (knob)
            space = space_class(axes, policy, **kwargs)
            self.space_map[name] = space
            self._entity_map[name] = space[0]
            return [Axis(space, i) for i in range(space.num_output)]
        return [Axis(None, i) for i in range(space_class.get_num_output(axes, policy, **kwargs))]

    def __len__(self):
        if self._length is None:
            self._length = int(np.prod([len(x) for x in self.space_map.values()]))
        return self._length

    def get(self, index):
        """Get a config entity with detailed parameters from this space

        Parameters
        ----------
        index: int
            index in the space
        """
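        # decode the flat index as a mixed-radix number: each knob consumes one
        # digit whose base is the size of its subspace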
        entities = OrderedDict()
        t = index
        for name, space in self.space_map.items():
            entities[name] = space[t % len(space)]
            t //= len(space)
        ret = ConfigEntity(index, self.code_hash, self.template_key, entities, self._constraints)
        return ret

    def __iter__(self):
        return self._entity_map.__iter__()

    def __getitem__(self, name):
        """get the transform entity (knob) of this space by name
           do not use this to get a ConfigEntity of this space (use ConfigSpace.get instead)

        Parameters
        ----------
        name: str
            name of the transform
        """
        return self._entity_map[name]

    def __repr__(self):
        res = "ConfigSpace (len=%d, space_map=\n" % len(self)
        for i, (name, space) in enumerate(self.space_map.items()):
            res += "  %2d %s: %s\n" % (i, name, space)
        return res + ")"


_ann_to_number = {
    'none': 0, 'vec': 1, 'unroll': 2,
    'blockIdx.x': 3, 'blockIdx.y': 4, 'blockIdx.z': 5,
    'threadIdx.x': 6, 'threadIdx.y': 7, 'threadIdx.z': 8,
    'vthread': 9, 'fuse': 10
}

class ConfigEntity(ConfigSpace):
    """A configuration with detailed parameters

    Parameters
    ----------
    index: int
        index of this config in space
    code_hash: str
        hash of schedule code
    template_key : str
        The specific template key
    entity_map: dict
        map name to transform entity
    constraints : list
        List of constraints
    """
    def __init__(self, index, code_hash, template_key, entity_map, constraints):
        super(ConfigEntity, self).__init__()
        self.index = index
        self.template_key = template_key
        self._collect = False
        self._entity_map = entity_map
        self._space_map = None
        self._constraints = constraints
        self.code_hash = code_hash

    def get_flatten_feature(self):
        """ flatten entities to a numerical one-dimensional feature vector

        Returns
        -------
        fea: np.array
            one dimensional float32 array
        """
        fea = []
        for _, v in self._entity_map.items():
            if isinstance(v, SplitEntity):
                fea.extend(v.size)
            elif isinstance(v, ReorderEntity):
                # use a naive way: directly copy the permutation
                fea.extend(v.perm)
            elif isinstance(v, AnnotateEntity):
                # one-hot encoding
                for ann in v.anns:
                    tmp = [0] * len(_ann_to_number)
                    tmp[_ann_to_number[ann]] = 1
                    fea.extend(tmp)
            elif isinstance(v, OtherOptionEntity):
                fea.append(v.val)
        return np.array(fea, dtype=np.float32)

    def get_other_option(self):
        """
        Returns
        -------
        other_option: dict
            other tunable parameters (tunable parameters defined by `cfg.define_knob`)
        """
        return {x: x.val for x in self._entity_map.values() if isinstance(x, OtherOptionEntity)}

    def to_json_dict(self):
        """convert to a json serializable dictionary

        Returns
        -------
        json_dict: dict
            a json serializable dictionary
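
            For example, a config with a single split knob might serialize to
            something like (values are illustrative):
            {'i': 42, 't': 'direct', 'c': None, 'e': [('tile_x', 'sp', [-1, 8])]}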
        """
        ret = {}
        ret['i'] = int(self.index)
        ret['t'] = self.template_key
        ret['c'] = self.code_hash
        entity_map = []
        for k, v in self._entity_map.items():
            if isinstance(v, SplitEntity):
                entity_map.append((k, 'sp', v.size))
            elif isinstance(v, ReorderEntity):
                entity_map.append((k, 're', v.perm))
            elif isinstance(v, AnnotateEntity):
                entity_map.append((k, 'an', v.anns))
            elif isinstance(v, OtherOptionEntity):
                entity_map.append((k, 'ot', v.val))
            else:
                raise RuntimeError("Invalid entity instance: " + str(v))
        ret['e'] = entity_map
        return ret

    @staticmethod
    def from_json_dict(json_dict):
        """Build a ConfigEntity from json serializable dictionary

        Parameters
        ----------
        json_dict: dict
            Json serializable dictionary. This should be the return value
            of :any:`to_json_dict`.

        Returns
        -------
        config: ConfigEntity
            The corresponding config object

        """
        index = json_dict["i"]
        code_hash = json_dict["c"]
        template_key = json_dict["t"]
        constraints = []
        entity_map = OrderedDict()

        for item in json_dict["e"]:
            key, knob_type, knob_args = item
            if knob_type == 'sp':
                entity = SplitEntity(knob_args)
            elif knob_type == 're':
                entity = ReorderEntity(knob_args)
            elif knob_type == 'an':
                entity = AnnotateEntity(knob_args)
            elif knob_type == 'ot':
                entity = OtherOptionEntity(knob_args)
            else:
                raise RuntimeError("Invalid config knob type: " + knob_type)
            entity_map[str(key)] = entity

        return ConfigEntity(index, code_hash, template_key, entity_map, constraints)

    def __repr__(self):
        return "%s,%s,%s,%d" % (str(self._entity_map)[12:-1], self.template_key,
                                self.code_hash, self.index)


class FallbackConfigEntity(ConfigSpace):
    """The config entity created to support fallback"""

    def __init__(self):
        super(FallbackConfigEntity, self).__init__()
        self.is_fallback = True

    def fallback_split(self, name, constraints):
        """Fallback a split knob

        Parameters
        ----------
        name: str
            name of the knob
        constraints: List of int
            The maximum tile size for every dimension. Value `-1` means no constraint.

        Examples
        --------
        If you use cfg.define_split('tile_0', 128, num_outputs=3),
        then cfg.fallback_split('tile_0', [-1, 8, 4]) will give you cfg['tile_0'].size = [4, 8, 4]

        If you use cfg.define_split('tile_0', 49, num_outputs=3),
        then cfg.fallback_split('tile_0', [-1, 8, 4]) will give you cfg['tile_0'].size = [7, 7, 1]
        """
        space = self.space_map[name]
        assert isinstance(space, SplitSpace)
        assert len(constraints) == space.num_output

        # '-1' means no constraint
        constraints = [x if x != -1 else 1e10 for x in constraints]

        entity = self._entity_map[name]
        now = space.product

        for i in reversed(range(space.num_output)):
            factors = get_factors(now)

            find = len(factors) - 1
            for j, f in enumerate(factors):
                if f > constraints[i]:
                    find = j - 1
                    break

            if find >= 0:
                entity.size[i] = factors[find]
                now //= factors[find]
            else:
                raise RuntimeError("Cannot find feasible fallback split entity for node: " + name)

    def fallback_with_reference_log(self, ref_log):
        """A data driven fallback mechanism.
        We use tuned parameters from TopHub as reference data.
        For an unseen shape, we find the most similar tuned one from TopHub and
        mimic its parameters.
        Note that we are not matching by workload (e.g., input size, kernel size),
        but instead matching by configuration space. The idea is that if two workloads have
        similar configuration space, their optimal configurations are also likely to be similar.

        Parameters
        ----------
        ref_log: List of (MeasureInput, MeasureResult)
            The reference log
        """
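        # Score each reference config by how much the factor sets of its split
        # products overlap with the factor sets of this workload's split
        # products, accumulated over all split knobs, then mimic the best match.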
        knob_names = [x for x in self.space_map.keys() if
                      isinstance(self.space_map[x], SplitSpace)]

        # find best match config in reference data by matching tiling factors
        factor_list = []
        for knob_name in knob_names:
            factor_list.append(get_factors(self.space_map[knob_name].product))

        best_match_cfg = None
        best_match_score = 0
        for inp, _ in ref_log:
            match_score = 0
            for i, knob_name in enumerate(knob_names):
                factors = get_factors(int(np.prod(inp.config[knob_name].size)))
                match_score += (float(len(set(factor_list[i]).intersection(factors))) /
                                len(factor_list[i]))

            if match_score > best_match_score:
                best_match_score, best_match_cfg = match_score, inp.config

        if best_match_cfg is None:
            return

        # mimic its tiling strategy
        for knob_name in knob_names:
            constraint = list(best_match_cfg[knob_name].size)
            constraint[0] = -1
            self.fallback_split(knob_name, constraint)

        # copy other knobs
        for knob_name in self.space_map.keys():
            if not isinstance(self.space_map[knob_name], SplitSpace):
                self._entity_map[knob_name] = best_match_cfg[knob_name]

    def __setitem__(self, name, entity):
        """set the entity (knob) of this space by name

        Parameters
        ----------
        name: str
            name of the entity
        entity: SplitEntity, ReorderEntity, AnnotateEntity, OtherOptionEntity
            value of the entity
        """
        self._entity_map[name] = entity

    def __repr__(self):
        return "%s,%s,%s" % (str(self._entity_map)[12:-1], self.template_key, self.code_hash)