1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=unused-variable,not-callable
18"""Definition of task function.
19
20Task can be constructed from tuple of func, args, and kwargs.
21func is a state-less function, or a string that
22registers the standard task.
23"""
24import numpy as np
25
26from tvm.target import Target
27from tvm import runtime
28from tvm.ir import container
29from tvm.tir import expr
30from tvm.te import tensor, placeholder
31
32
33from ..util import get_const_int, get_const_tuple
34from .dispatcher import DispatchContext, ApplyConfig
35from .space import ConfigSpace
36
37
38def _raise_error(*args, **kwargs):  # pylint: disable=unused-argument
39    raise RuntimeError(
40        "The function of this task is not found. Possibly the function "
41        "of this task is registered in another python file "
42        "which is not imported in this run"
43    )
44
45
46def serialize_args(args):
47    """serialize arguments of a topi function to a hashable tuple.
48
49    Parameters
50    ----------
51    args: list of hashable or Tensor
52    """
53
54    def _encode(x):
55        if isinstance(x, tensor.Tensor):
56            return ("TENSOR", get_const_tuple(x.shape), x.dtype)
57        if isinstance(x, (tuple, list, container.Array)):
58            return tuple([_encode(a) for a in x])
59        if isinstance(x, (str, int, float, np.int, np.float, expr.Var, expr.Any)):
60            return x
61        if isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
62            return x.value
63        if isinstance(x, runtime.container.String):
64            return str(x)
65        if x is None:
66            return None
67        raise RuntimeError(
68            'Do not support type "%s" in argument. Consider to use'
69            "primitive types or tvm.tir.Var only" % type(x)
70        )
71
72    ret = []
73    for t in args:
74        ret.append(_encode(t))
75    return tuple(ret)
76
77
78def deserialize_args(args):
79    """The inverse function of :code:`serialize_args`.
80
81    Parameters
82    ----------
83    args: list of hashable or Tensor
84    """
85    ret = []
86    for t in args:
87        if isinstance(t, tuple) and t[0] == "TENSOR":
88            ret.append(placeholder(shape=t[1], dtype=t[2]))
89        else:
90            ret.append(t)
91    return ret
92
93
94def args_to_workload(args, task_name=None):
95    """Convert argument list to hashable workload tuple.
96    This function will convert list to tuple, tvm node to python value and
97    flatten te.tensor.Tensor to a tuple
98
99    Parameters
100    ----------
101    task_name : str
102        The AutoTVM task name
103
104    args : list of args
105        The arguments to the function
106
107    Returns
108    -------
109    ret: hashable
110        The hashable value
111    """
112    return (task_name,) + serialize_args(args) if task_name is not None else serialize_args(args)
113
114
115class Task(object):
116    """A Tunable Task
117
118    Parameters
119    ----------
120    name: str
121        The name of the task.
122    args: Tuple
123        Positional argument of func
124    """
125
126    def __init__(self, name, args):
127        self.name = name
128        self.args = args
129        self.kwargs = {}  # currently unused
130
131        # init null config space
132        self.config_space = None
133        self.func = TASK_TABLE.get(name, _raise_error)
134
135        # auxiliary info, available after `init_space` is called
136        self.flop = None
137        self.target = None
138        self.target_host = None
139
140    @property
141    def workload(self):
142        return (self.name,) + serialize_args(self.args)
143
144    def instantiate(self, config):
145        """Instantiate this task function (template) with a config.
146        Returns corresponding schedule.
147
148        Parameters
149        ----------
150        config: template.ConfigEntity
151            parameter config for this template
152
153        Returns
154        -------
155        sch: tvm.te.schedule.Schedule
156            The tvm schedule
157        arg_bufs: Array of te.tensor.Tensor
158            The input/output buffers
159        """
160        config.flop = 0
161        with ApplyConfig(config):
162            sch, arg_bufs = self.func(*self.args, **self.kwargs)
163        if not self.flop:
164            config.flop = config.flop or compute_flop(sch)
165            self.flop = config.flop
166        return sch, arg_bufs
167
168    def __getstate__(self):
169        # custom pickle implementation is required for
170        # some unpickable local task functions.
171        # So we only pickle the name of the function
172        # and restore the function by name when unpickling it.
173        return {
174            "name": self.name,
175            "args": self.args,
176            "kwargs": self.kwargs,
177            "config_space": self.config_space,
178            "flop": self.flop,
179            "target": self.target,
180            "target_host": self.target_host,
181        }
182
183    def __setstate__(self, state):
184        self.name = state["name"]
185        self.args = state["args"]
186        self.kwargs = state["kwargs"]
187        self.config_space = state["config_space"]
188        self.func = TASK_TABLE.get(state["name"], _raise_error)
189        self.flop = state["flop"]
190        self.target = state["target"]
191        self.target_host = state["target_host"]
192
193    def __repr__(self):
194        return "Task(func_name=%s, args=%s, kwargs=%s, workload=%s)" % (
195            self.name,
196            self.args,
197            self.kwargs,
198            self.workload,
199        )
200
201
202TASK_TABLE = {}
203
204
205class TaskTemplate(object):
206    """
207    Task template is used to creates a tunable AutoTVM task.
208
209    It can be defined by a pair of compute and schedule function using
210    `_register_task_compute` and `_register_task_schedule`,
211    or by a customized task creation function that is more flexible using
212    `_register_customized_task`.
213
214    Note that when customized func is registered, compute and schedule function
215    will be ignored
216    """
217
218    def __init__(self):
219        self.fcompute = None
220        self.fschedule = None
221        self.fcustomized = None
222
223    def __call__(self, *args, **kwargs):
224        args = deserialize_args(args)
225        if self.fcustomized is None:
226            return self._default_func(*args, **kwargs)
227        assert callable(self.fcustomized)
228        return self.fcustomized(*args, **kwargs)
229
230    def _default_func(self, *args, **kwargs):
231        assert callable(self.fcompute) and callable(self.fschedule)
232        out = self.fcompute(*args, **kwargs)
233        arg_bufs = [out] + self._get_inputs(out)
234        s = self.fschedule([out])
235        return s, arg_bufs
236
237    @staticmethod
238    def _get_inputs(out):
239        inputs = []
240        queue = [out]
241        hash_set = set()
242        while queue:
243            t = queue.pop(0)
244            if isinstance(t.op, tensor.PlaceholderOp):
245                inputs.append(t)
246            else:
247                input_tensors = [t for t in t.op.input_tensors if t not in hash_set]
248                queue.extend(input_tensors)
249                hash_set.update(input_tensors)
250        return inputs
251
252
253def _register_task_compute(name, func=None):
254    """Register compute function to autotvm task
255
256    Parameters
257    ----------
258    name: str
259        The task name
260
261    func: None or callable
262        If it is None, return a decorator.
263        If is callable, decorate this function.
264
265    Returns
266    -------
267    decorator: callable
268        A decorator
269    """
270
271    def _do_reg(f):
272        if name not in TASK_TABLE:
273            TASK_TABLE[name] = TaskTemplate()
274        tmpl = TASK_TABLE[name]
275        if tmpl.fcompute is not None:
276            raise ValueError("Compute is already registered in autoTVM task %s" % name)
277        tmpl.fcompute = f
278        return f
279
280    if func:
281        return _do_reg(func)
282    return _do_reg
283
284
285def _register_task_schedule(name, func=None):
286    """Register schedule function to autotvm task
287
288    Parameters
289    ----------
290    name: str
291        The task name
292
293    func: None or callable
294        If it is None, return a decorator.
295        If is callable, decorate this function.
296
297    Returns
298    -------
299    decorator: callable
300        A decorator
301    """
302
303    def _do_reg(f):
304        if name not in TASK_TABLE:
305            TASK_TABLE[name] = TaskTemplate()
306        tmpl = TASK_TABLE[name]
307        if tmpl.fschedule is not None:
308            raise ValueError("Schedule is already registered in autoTVM task %s" % name)
309        tmpl.fschedule = f
310        return f
311
312    if func:
313        return _do_reg(func)
314    return _do_reg
315
316
317def _register_customized_task(name, func=None):
318    """Register a customized function to AutoTVM task.
319
320    Parameters
321    ----------
322    name: str
323        The task name
324
325    func: None or callable
326        If it is None, return a decorator.
327        If is callable, decorate this function.
328
329    Returns
330    -------
331    decorator: callable
332        A decorator
333    """
334
335    def _do_reg(f):
336        if name not in TASK_TABLE:
337            TASK_TABLE[name] = TaskTemplate()
338        tmpl = TASK_TABLE[name]
339        if tmpl.fcustomized is not None:
340            raise ValueError("Customized func is already registered in autoTVM task %s" % name)
341        tmpl.fcustomized = f
342        return f
343
344    if func:
345        return _do_reg(func)
346    return _do_reg
347
348
349def template(task_name, func=None):
350    """Decorate a function as a tunable schedule template.
351
352    Parameters
353    ----------
354    task_name: str
355        The task name
356
357    func: None or callable
358        A callable template function.
359        If it is None, return a decorator.
360        If is callable, decorate this function.
361
362    Returns
363    -------
364    func: callable
365        The decorated function
366
367    Examples
368    --------
369    The following code is a tunable template for a blocked matrix multiplication
370
371    .. code-block:: python
372
373        @autotvm.template("matmul")
374        def matmul(N, L, M, dtype):
375            A = te.placeholder((N, L), name='A', dtype=dtype)
376            B = te.placeholder((L, M), name='B', dtype=dtype)
377
378            k = te.reduce_axis((0, L), name='k')
379            C = te.compute((N, M), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name='C')
380            s = te.create_schedule(C.op)
381
382            # schedule
383            y, x = s[C].op.axis
384            k = s[C].op.reduce_axis[0]
385
386            ##### define space begin #####
387            cfg = autotvm.get_config()
388            cfg.define_split("tile_y", y, num_outputs=2)
389            cfg.define_split("tile_x", x, num_outputs=2)
390            ##### define space end #####
391
392            # schedule according to config
393            yo, yi = cfg["tile_y"].apply(s, C, y)
394            xo, xi = cfg["tile_x"].apply(s, C, x)
395
396            s[C].reorder(yo, xo, k, yi, xi)
397
398            return s, [A, B, C]
399    """
400
401    def _decorate(f):
402        def wrapper(*args, **kwargs):
403            assert not kwargs, "Do not support kwargs in template function call"
404            workload = args_to_workload(args, task_name)
405            tgt = Target.current()
406            cfg = DispatchContext.current.query(tgt, workload)
407            with ApplyConfig(cfg):
408                return f(*args, **kwargs)
409
410        _register_customized_task(task_name, f)
411        return wrapper
412
413    if func:
414        return _decorate(func)
415    return _decorate
416
417
418def create(task_name, args, target, target_host=None):
419    """Create a tuning task and initialize its search space
420
421    Parameters
422    ----------
423    task_name : str
424        The AutoTVM task name
425    args : List
426        Positional arguments
427    target : Target
428        The compilation target
429    target_host: Target, optional
430        The compilation target for host side
431
432    Returns
433    -------
434    tsk: Task
435        a task object
436    """
437    args = serialize_args(args)
438    ret = Task(task_name, args)
439
440    if isinstance(target, str):
441        target = Target(target)
442
443    # init config space
444    ret.config_space = ConfigSpace()
445
446    ctx = ApplyConfig(ret.config_space)
447    with ctx:
448        with target:
449            sch, _ = ret.func(*args)
450            ret.config_space.code_hash = getattr(sch, "code_hash", None)
451
452    ret.flop = ret.config_space.flop or compute_flop(sch)
453    ret.target = target
454    ret.target_host = target_host
455
456    return ret
457
458
459def get_config():
460    """Get current config object
461
462    Returns
463    -------
464    cfg: ConfigSpace or ConfigEntity
465        The current config
466    """
467    tgt = Target.current(allow_none=True)
468    return DispatchContext.current.query(tgt, None)
469
470
471class FlopCalculationError(RuntimeError):
472    """Error happens when estimating FLOP for a compute op"""
473
474
475def compute_flop(sch):
476    """Calculate number of FLOP (floating number operations) of the compute ops in a schedule
477
478    Parameters
479    ----------
480    sch: tvm.te.schedule.Schedule
481        schedule
482
483    Returns
484    -------
485    flop: int
486        number of FLOP in this schedule
487    """
488
489    def _prod_length(axes):
490        """compute product of the lengths of a list of axes"""
491        try:
492            num_iter = int(np.prod([get_const_int(axis.dom.extent) for axis in axes]))
493        except ValueError:
494            raise FlopCalculationError("The length of axis is not constant. ")
495        return num_iter
496
497    def _count_flop(exp):
498        """compute flop for a single expression"""
499        if isinstance(exp, expr.Reduce):
500            num_iter = _prod_length(exp.axis)
501            combiner = exp.combiner.result
502            source = exp.source
503            if len(combiner) != 1:
504                raise FlopCalculationError("Found multiple output in the combiner of reduce op")
505            if len(source) != 1:
506                raise FlopCalculationError("Found multiple output in the source of reduce op")
507            return num_iter * (_count_flop(combiner[0]) + _count_flop(source[0]))
508        if isinstance(exp, (expr.FloatImm, expr.IntImm)):
509            return 0
510        if isinstance(exp, expr.Cast):
511            return _count_flop(exp.value)
512        if isinstance(exp, expr.Var):
513            return 0
514        if isinstance(
515            exp,
516            (
517                expr.Add,
518                expr.Sub,
519                expr.Mul,
520                expr.Div,
521                expr.Mod,
522                expr.FloorDiv,
523                expr.FloorMod,
524                expr.Max,
525                expr.Min,
526                expr.EQ,
527                expr.NE,
528                expr.LT,
529                expr.LE,
530                expr.GT,
531                expr.GE,
532                expr.And,
533                expr.Or,
534                expr.Not,
535            ),
536        ):
537            base = 1
538
539            if isinstance(exp, expr.Not):  # unary
540                return base + _count_flop(exp.a)
541
542            return base + _count_flop(exp.a) + _count_flop(exp.b)
543        if isinstance(exp, expr.Select):
544            return _count_flop(exp.condition) + max(
545                _count_flop(exp.true_value), _count_flop(exp.false_value)
546            )
547        if isinstance(exp, expr.ProducerLoad):
548            # Ignore flops from indexing expressions.
549            return 0
550
551        if isinstance(exp, expr.Call):
552            return sum([_count_flop(x) for x in exp.args])
553
554        raise FlopCalculationError("Found unsupported operator in the compute expr")
555
556    def traverse(ops):
557        """accumulate flops"""
558        ret = 0
559        for op in ops:
560            if isinstance(op, tensor.ComputeOp):
561                num_element = _prod_length(op.axis)
562
563                body = op.body
564                if len(body) != 1:
565                    raise FlopCalculationError("Found multiple output in the compute")
566                exp = body[0]
567
568                ret += num_element * _count_flop(exp)
569                ret += traverse([t.op for t in op.input_tensors])
570
571            elif isinstance(op, tensor.PlaceholderOp):
572                pass
573            else:
574                raise FlopCalculationError(
575                    "Only support te.compute currently. "
576                    "Other ops like tvm.te.scan/te.extern is not supported"
577                )
578        return ret
579
580    try:
581        ret = traverse(sch.outputs)
582    except FlopCalculationError as exc:
583        raise RuntimeError(
584            "FLOP estimator fails for this operator. Error msg: "
585            + str(exc)
586            + ". Please use `cfg.add_flop` to manually set "
587            "FLOP for this operator"
588        )
589
590    if ret == 0:
591        raise RuntimeError(
592            "Cannot find float number operation in this operator. "
593            "Please use `cfg.add_flop` to manually set "
594            "FLOP for this operator"
595        )
596    return ret
597