1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=unused-variable
18"""Definition of task function.
19
20Task can be constructed from tuple of func, args, and kwargs.
21func is a state-less function, or a string that
22registers the standard task.
23"""
24
25import numpy as np
26
27from ... import tensor, expr, container, target as _target
28
29from ..util import get_const_int, get_const_tuple, get_func_name
30from .dispatcher import DispatchContext, ApplyConfig, dispatcher
31from .space import ConfigSpace
32
33def _raise_error(*args, **kwargs):  # pylint: disable=unused-argument
34    raise RuntimeError("The function of this task is not found. Possibly the function "
35                       "of this task is registered in another python file "
36                       "which is not imported in this run")
37
class Task(object):
    """A Tunable Task

    Parameters
    ----------
    name: str
        The name of the task.
    args: Tuple
        Positional argument of func
    """
    def __init__(self, name, args):
        self.name = name
        self.args = args
        self.kwargs = {}  # currently unused

        # search space starts empty; it is filled in by `create`
        self.config_space = None
        self.func = TASK_TABLE.get(name, _raise_error)

        # auxiliary info, available after `init_space` is called
        self.workload = None
        self.flop = None
        self.target = None
        self.target_host = None

    def instantiate(self, config):
        """Instantiate this task function (template) with a config.
        Returns corresponding schedule.

        Parameters
        ----------
        config: template.ConfigEntity
            parameter config for this template

        Returns
        -------
        sch: tvm.schedule.Schedule
            The tvm schedule
        arg_bufs: Array of tvm.tensor.Tensor
            The input/output buffers
        """
        config.flop = 0
        with ApplyConfig(config):
            schedule, buffers = self.func(*self.args, **self.kwargs)
        if not self.flop:
            # fall back to static estimation when the template did not
            # report a flop count through the config
            config.flop = config.flop or compute_flop(schedule)
            self.flop = config.flop
        return schedule, buffers

    def __getstate__(self):
        # Task functions may be unpicklable locals, so only the function
        # *name* is pickled; __setstate__ re-resolves it from TASK_TABLE.
        return dict(
            name=self.name,
            args=self.args,
            kwargs=self.kwargs,
            config_space=self.config_space,
            workload=self.workload,
            flop=self.flop,
            target=self.target,
            target_host=self.target_host,
        )

    def __setstate__(self, state):
        # restore plain attributes, then re-resolve the function by name
        for key in ("name", "args", "kwargs", "config_space",
                    "workload", "flop", "target", "target_host"):
            setattr(self, key, state[key])
        self.func = TASK_TABLE.get(state["name"], _raise_error)

    def __repr__(self):
        return "Task(func_name={}, args={}, kwargs={}, workload={})".format(
            self.name, self.args, self.kwargs, self.workload)
118
# Global registry mapping task name -> task function.
TASK_TABLE = {
}

def register(name, func=None, override=False):
    """Register a task function.

    Can be used directly (``register(name, func=f)``) or as a decorator
    factory (``@register(name)``).

    Parameters
    ----------
    name : str
        The name to identify the task.
    func : callable
        The function to be registered.
    override : bool
        Whether override existing registration.

    Returns
    -------
    func: callable
        The registered function
    """
    def _register(f):
        # refuse to clobber an existing entry unless explicitly allowed
        if not override and name in TASK_TABLE:
            raise ValueError("Key %s is already registered" % name)
        TASK_TABLE[name] = f
        return f
    return _register(func) if func else _register
148
def create(func_name, args, target, target_host=None, template_key=None):
    """Create a tuning task and initialize its search space

    Parameters
    ----------
    func_name : str or callable
        The task function
    args : List
        Positional arguments
    target : Target
        The compilation target
    target_host: Target, optional
        The compilation target for host side
    template_key: str, optional
        The template key stored on the config space (defaults to "")

    Returns
    -------
    tsk: Task
        a task object
    """
    if callable(func_name):
        # register this function if it is not registered before
        func = func_name
        if hasattr(func, 'func_name'):
            func_name = func.func_name
        else:
            func_name = func.__name__
        if func_name in TASK_TABLE:
            assert func == TASK_TABLE[func_name], "Find name conflict in task registration. " \
                                                  "Consider to choose another name for this task"
        else:
            register(func_name, func=func)

    func = TASK_TABLE[func_name]
    task = Task(func_name, args)

    if isinstance(target, str):
        target = _target.create(target)

    # build an empty search space, then run the template once under it so
    # the template's knob definitions populate the space
    space = ConfigSpace()
    space.template_key = template_key or ""
    task.config_space = space

    ctx = ApplyConfig(space)
    with ctx, target:
        sch, _ = func(*args)
        space.code_hash = getattr(sch, 'code_hash', None)

    task.workload = ctx.workload
    # prefer the flop count the template reported; otherwise estimate it
    task.flop = space.flop or compute_flop(sch)
    task.target = target
    task.target_host = target_host
    return task
200
def args_to_workload(x, topi_compute_func=None):
    """Convert argument list to hashable workload tuple.
    This function will convert list to tuple, tvm node to python value and
    flatten tvm.tensor.Tensor to a tuple

    Parameters
    ----------
    x: primitive hashable types or tensor.Tensor
        The original value
    topi_compute_func: topi compute function
        The function name will be added as first element of the workload tuple

    Returns
    -------
    ret: hashable
        The hashable value

    Raises
    ------
    RuntimeError
        If `x` (or a nested element) is of an unsupported type.
    """
    if isinstance(x, tensor.Tensor):
        # a tensor is represented by its static shape plus dtype
        workload = get_const_tuple(x.shape) + (x.dtype, )
    elif isinstance(x, (tuple, list, container.Array)):
        workload = tuple(args_to_workload(a) for a in x)
    elif isinstance(x, (str, int, float, expr.Var)):
        # NOTE: the former `np.int`/`np.float` entries were exact aliases of
        # the builtin `int`/`float` already listed here; they were removed
        # because NumPy >= 1.24 deletes those aliases (AttributeError).
        workload = x
    elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)):
        # unwrap tvm immediates to their plain python value
        workload = x.value
    elif x is None:
        workload = 0
    else:
        raise RuntimeError('Do not support type "%s" in argument. Consider to use '
                           'primitive types or tvm.expr.Var only' % type(x))
    return (get_func_name(topi_compute_func), ) + workload if topi_compute_func else workload
232
def template(func):
    """
    Decorate a function as a tunable schedule template

    Parameters
    ----------
    func: callable
        A callable template function.
        Its argument should be hashable values.
        Its return value should be a Tuple(Schedule, Array of Tensor)

    Returns
    -------
    func: callable
        The decorated function

    Examples
    --------
    The following code is a tunable template for a blocked matrix multiplication

    .. code-block:: python

        @autotvm.template
        def matmul(N, L, M, dtype):
            A = tvm.placeholder((N, L), name='A', dtype=dtype)
            B = tvm.placeholder((L, M), name='B', dtype=dtype)

            k = tvm.reduce_axis((0, L), name='k')
            C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')
            s = tvm.create_schedule(C.op)

            # schedule
            y, x = s[C].op.axis
            k = s[C].op.reduce_axis[0]

            ##### define space begin #####
            cfg = autotvm.get_config()
            cfg.define_split("tile_y", y, num_outputs=2)
            cfg.define_split("tile_x", x, num_outputs=2)
            ##### define space end #####

            # schedule according to config
            yo, yi = cfg["tile_y"].apply(s, C, y)
            xo, xi = cfg["tile_x"].apply(s, C, x)

            s[C].reorder(yo, xo, k, yi, xi)

            return s, [A, B, C]
    """
    # pylint: disable=unused-variable

    fname = get_func_name(func)

    # Decorator order matters: `dispatcher` wraps the workload function
    # first, then `register(fname)` stores the resulting dispatcher in
    # TASK_TABLE as this task's function.
    @register(fname)
    @dispatcher
    def config_dispatcher(*args, **kwargs):
        # Maps a call's arguments to a hashable workload tuple keyed by the
        # template name (used by the dispatcher for config lookup —
        # see dispatcher.py for the exact lookup semantics).
        assert not kwargs, "Do not support kwargs in template function call"
        return (fname, ) + args_to_workload(args)

    # Register the actual template body under the empty template key "";
    # presumably the dispatcher selects this implementation and passes the
    # resolved config as `cfg` — verify against dispatcher.register.
    @config_dispatcher.register("")
    def template_call(cfg, *args, **kwargs):
        assert not kwargs, "Do not support kwargs in template function call"
        # run the template body with `cfg` installed as the active config
        with ApplyConfig(cfg):
            return func(*args, **kwargs)

    # expose the original template name on the dispatcher for lookup/repr
    config_dispatcher.func_name = fname
    return config_dispatcher
300
def get_config():
    """Get current config object

    Returns
    -------
    cfg: ConfigSpace or ConfigEntity
        The current config
    """
    # delegate to whatever dispatch context is currently active
    ctx = DispatchContext.current
    return ctx.query(None, None)
310
class FlopCalculationError(RuntimeError):
    """Raised when the FLOP count of a compute op cannot be estimated."""
313
314
def compute_flop(sch):
    """Calculate number of FLOP (floating number operations) of the compute ops in a schedule

    Parameters
    ----------
    sch: tvm.schedule.Schedule
        schedule

    Returns
    -------
    flop: int
        number of FLOP in this schedule

    Raises
    ------
    RuntimeError
        If the estimate fails (non-compute ops, non-constant axis extents,
        unsupported expressions) or if the resulting count is zero; the
        message suggests setting FLOP manually via `cfg.add_flop`.
    """
    def _prod_length(axes):
        """compute product of the lengths of a list of axes"""
        try:
            num_iter = int(np.prod([get_const_int(axis.dom.extent) for axis in axes]))
        except ValueError:
            # a non-constant (symbolic) extent surfaces here as ValueError
            raise FlopCalculationError("The length of axis is not constant. ")
        return num_iter

    def _count_flop(exp):
        """compute flop for a single expression"""
        if isinstance(exp, expr.Reduce):
            num_iter = _prod_length(exp.axis)
            combiner = exp.combiner.result
            source = exp.source
            if len(combiner) != 1:
                raise FlopCalculationError("Found multiple output in the combiner of reduce op")
            if len(source) != 1:
                raise FlopCalculationError("Found multiple output in the source of reduce op")
            # each reduction iteration evaluates the source once and
            # combines it once
            return num_iter * (_count_flop(combiner[0]) + _count_flop(source[0]))
        if isinstance(exp, (expr.FloatImm, expr.IntImm, expr.UIntImm)):
            # constants cost nothing
            return 0
        if isinstance(exp, expr.Cast):
            return _count_flop(exp.value)
        if isinstance(exp, expr.Var):
            return 0
        if isinstance(exp, (expr.Add, expr.Sub, expr.Mul,
                            expr.Div, expr.Mod,
                            expr.FloorDiv, expr.FloorMod,
                            expr.Max, expr.Min,
                            expr.EQ, expr.NE, expr.LT, expr.LE, expr.GT, expr.GE,
                            expr.And, expr.Or, expr.Not)):
            # every arithmetic / comparison / logic node counts as one op
            base = 1

            if isinstance(exp, expr.Not):  # unary
                return base + _count_flop(exp.a)

            return base + _count_flop(exp.a) + _count_flop(exp.b)
        if isinstance(exp, expr.Select):
            # charge the condition plus the more expensive of the branches
            return _count_flop(exp.condition) + max(_count_flop(exp.true_value),
                                                    _count_flop(exp.false_value))
        if isinstance(exp, expr.Call):
            if exp.call_type == expr.Call.Halide:
                # Ignore flops from indexing expressions.
                return 0

            # other calls: sum the cost of evaluating the arguments
            return sum([_count_flop(x) for x in exp.args])

        raise FlopCalculationError("Found unsupported operator in the compute expr")

    def traverse(ops):
        """accumulate flops"""
        ret = 0
        for op in ops:
            if isinstance(op, tensor.ComputeOp):
                # total = (number of output elements) * (flop per element)
                num_element = _prod_length(op.axis)

                body = op.body
                if len(body) != 1:
                    raise FlopCalculationError("Found multiple output in the compute")
                exp = body[0]

                ret += num_element * _count_flop(exp)
                # recurse into the producer ops feeding this compute
                ret += traverse([t.op for t in op.input_tensors])

            elif isinstance(op, tensor.PlaceholderOp):
                # graph inputs contribute no computation
                pass
            else:
                raise FlopCalculationError("Only support tvm.compute currently. "
                                           "Other ops like tvm.scan/tvm.extern is not supported")
        return ret

    try:
        ret = traverse(sch.outputs)
    except FlopCalculationError as exc:
        # re-raise as a plain RuntimeError with a manual-override hint
        raise RuntimeError("FLOP estimator fails for this operator. Error msg: "
                           + str(exc) + ". Please use `cfg.add_flop` to manually set "
                                        "FLOP for this operator")

    if ret == 0:
        # a zero count means no float ops were found; treat as an error so
        # callers do not silently tune with a meaningless metric
        raise RuntimeError("Cannot find float number operation in this operator. "
                           "Please use `cfg.add_flop` to manually set "
                           "FLOP for this operator")
    return ret
411