# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17"""Functions defined in TVM."""
18# pylint: disable=invalid-name,unused-import,redefined-builtin
19from __future__ import absolute_import as _abs
20
21from numbers import Integral as _Integral
22
23from ._ffi.base import string_types
24from ._ffi.object import register_object, Object
25from ._ffi.node import register_node, NodeBase
26from ._ffi.node import convert_to_node as _convert_to_node
27from ._ffi.node_generic import _scalar_type_inference
28from ._ffi.function import Function
29from ._ffi.function import _init_api, register_func, get_global_func, extract_ext_funcs
30from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
31from ._ffi.runtime_ctypes import TVMType
32from . import _api_internal
33from . import make as _make
34from . import expr as _expr
35from . import tensor as _tensor
36from . import schedule as _schedule
37from . import container as _container
38from . import tag as _tag
39
40int8 = "int8"
41int32 = "int32"
42float32 = "float32"
43handle = "handle"
44
45
def min_value(dtype):
    """minimum value of dtype

    Parameters
    ----------
    dtype : str
        The data type.

    Returns
    -------
    value : tvm.Expr
        The minimum value of dtype.
    """
    return _api_internal._min_value(dtype)


def max_value(dtype):
    """maximum value of dtype

    Parameters
    ----------
    dtype : str
        The data type.

    Returns
    -------
    value : tvm.Expr
        The maximum value of dtype.
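
    Example
    -------
    A minimal sketch:

    .. code-block:: python

        m = tvm.max_value("float32")  # Expr holding the maximum of float32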
    """
    return _api_internal._max_value(dtype)


def const(value, dtype=None):
    """construct a constant

    Parameters
    ----------
    value : number
        The content of the constant number.

    dtype : str or None, optional
        The data type.

    Returns
    -------
    const_val: tvm.Expr
        The result expression.
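
    Example
    -------
    A minimal sketch; when dtype is omitted it is inferred from value:

    .. code-block:: python

        c = tvm.const(1.0)                # dtype inferred as "float32"
        ci = tvm.const(1, dtype="int32")  # explicit dtype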
    """
    if dtype is None:
        dtype = _scalar_type_inference(value)
    return _api_internal._const(value, dtype)


def get_env_func(name):
    """Get an EnvFunc by a global name.

    Parameters
    ----------
    name: str
        The name of the global function.

    Returns
    -------
    env_func : EnvFunc
        The result env function.

    Note
    ----
    EnvFunc is a Node wrapper around a global function
    that can be serialized via its name.
    This can be used to serialize a function field in the language.
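
    Example
    -------
    A minimal sketch; ``mypkg.myadd`` is a hypothetical name used only
    for illustration:

    .. code-block:: python

        @tvm.register_func("mypkg.myadd")
        def myadd(x, y):
            return x + y

        f = tvm.get_env_func("mypkg.myadd")  # EnvFunc referring to "mypkg.myadd"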
    """
    return _api_internal._EnvFuncGet(name)


def convert(value):
    """Convert value to TVM node or function.

    Parameters
    ----------
    value : python value

    Returns
    -------
    tvm_val : Node or Function
        Converted value in TVM.
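
    Example
    -------
    A minimal sketch of the conversions performed:

    .. code-block:: python

        arr = tvm.convert([1, 2, 3])   # becomes a TVM Array node
        fn = tvm.convert(lambda x: x)  # becomes a packed Function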
    """
    if isinstance(value, (Function, NodeBase)):
        return value

    if callable(value):
        return _convert_tvm_func(value)

    return _convert_to_node(value)


def load_json(json_str):
    """Load tvm object from json_str.

    Parameters
    ----------
    json_str : str
        The json string.

    Returns
    -------
    node : Node
        The loaded tvm node.
    """
    return _api_internal._load_json(json_str)


def save_json(node):
    """Save tvm object as json string.

    Parameters
    ----------
    node : Node
        A TVM Node object to be saved.

    Returns
    -------
    json_str : str
        Saved json string.
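
    Example
    -------
    A minimal round-trip sketch with :any:`load_json`:

    .. code-block:: python

        x = tvm.var("x")
        json_str = tvm.save_json(x)
        y = tvm.load_json(json_str)  # a structurally equal copy of x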
    """
    return _api_internal._save_json(node)


def var(name="tindex", dtype=int32):
    """Create a new variable with the specified name and dtype.

    Parameters
    ----------
    name : str
        The name.

    dtype : str
        The data type.

    Returns
    -------
    var : Var
        The result symbolic variable.
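
    Example
    -------
    A minimal sketch:

    .. code-block:: python

        n = tvm.var("n")                  # int32 by default
        x = tvm.var("x", dtype="float32")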
    """
    return _api_internal._Var(name, dtype)


def any(*args):
    """Create a new expression of the union (logical or) of all conditions in the arguments.

    Parameters
    ----------
    args : list
        List of symbolic boolean expressions.

    Returns
    -------
    expr: Expr
        Expression
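
    Example
    -------
    A minimal sketch; :any:`all` is used the same way for the intersection:

    .. code-block:: python

        i = tvm.var("i")
        n = tvm.var("n")
        cond = tvm.any(i < 2, i > n - 2)  # i < 2 || i > n - 2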
    """
    if not args:
        raise ValueError("Any must take at least 1 argument")
    if len(args) == 1:
        return args[0]
    ret = _make._OpOr(args[0], args[1])
    for i in range(2, len(args)):
        ret = _make._OpOr(ret, args[i])
    return ret


def all(*args):
    """Create a new expression of the intersection (logical and) of all conditions
    in the arguments.

    Parameters
    ----------
    args : list
        List of symbolic boolean expressions.

    Returns
    -------
    expr: Expr
        Expression
    """
    if not args:
        raise ValueError("All must take at least 1 argument")
    if len(args) == 1:
        return args[0]
    ret = _make._OpAnd(args[0], args[1])
    for i in range(2, len(args)):
        ret = _make._OpAnd(ret, args[i])
    return ret


def placeholder(shape, dtype=None, name="placeholder"):
    """Construct an empty tensor object.

    Parameters
    ----------
    shape: Tuple of Expr
        The shape of the tensor.

    dtype: str, optional
        The data type of the tensor.

    name: str, optional
        The name hint of the tensor.

    Returns
    -------
    tensor: Tensor
        The created tensor
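
    Example
    -------
    A minimal sketch:

    .. code-block:: python

        m = tvm.var("m")
        n = tvm.var("n")
        A = tvm.placeholder((m, n), name="A")  # float32 by default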
    """
    shape = (shape,) if isinstance(shape, _expr.Expr) else shape
    dtype = float32 if dtype is None else dtype
    return _api_internal._Placeholder(
        shape, dtype, name)


def compute(shape, fcompute, name="compute", tag="", attrs=None):
    """Construct a new tensor by computing over the shape domain.

    The compute rule is result[axis] = fcompute(axis)

    Parameters
    ----------
    shape: Tuple of Expr
        The shape of the tensor.

    fcompute: lambda function of indices -> value
        Specifies the input source expression.

    name: str, optional
        The name hint of the tensor.

    tag: str, optional
        Additional tag information about the compute.

    attrs: dict, optional
        The additional auxiliary attributes about the compute.

    Returns
    -------
    tensor: Tensor
        The created tensor
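
    Example
    -------
    A minimal sketch that doubles every element of a placeholder:

    .. code-block:: python

        m = tvm.var("m")
        n = tvm.var("n")
        A = tvm.placeholder((m, n), name="A")
        B = tvm.compute((m, n), lambda i, j: A[i, j] * 2, name="B")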
    """
    if _tag.TagScope.get_current() is not None:
        if tag != "":
            raise ValueError("nested tag is not allowed for now")
        tag = _tag.TagScope.get_current().tag
    shape = (shape,) if isinstance(shape, _expr.Expr) else shape
    # for python3: coerce float dimensions in the shape to int
    shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
    ndim = len(shape)
    code = fcompute.__code__

    out_ndim = ndim
    if code.co_argcount == 0:
        arg_names = ["i%d" % i for i in range(ndim)]
    else:
        arg_names = code.co_varnames[:code.co_argcount]
        out_ndim = code.co_argcount

    if out_ndim != len(arg_names):
        raise ValueError("fcompute does not match dimension, ndim=%d" % ndim)

    dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])]
    body = fcompute(*[v.var for v in dim_var])

    if isinstance(body, _tensor.TensorIntrinCall):
        for i, s in enumerate(shape[out_ndim:]):
            var_name = "ax" + str(i)
            dim_var.append(_IterVar((0, s), var_name, 4))
        op_node = _api_internal._TensorComputeOp(name,
                                                 tag,
                                                 dim_var,
                                                 body.reduce_axis,
                                                 out_ndim,
                                                 body.intrin,
                                                 body.tensors,
                                                 body.regions,
                                                 body.scalar_inputs)
    else:
        if not isinstance(body, (list, tuple)):
            body = [body]
        body = convert(body)
        op_node = _api_internal._ComputeOp(
            name, tag, attrs, dim_var, body)

    num = op_node.num_outputs
    outputs = tuple(op_node.output(i) for i in range(num))
    return outputs[0] if num == 1 else outputs


def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attrs=None):
    """Construct new tensors by scanning over axis.

    Parameters
    ----------
    init: Tensor or list of Tensor
        The initial condition of the first init.shape[0] time steps.

    update: Tensor or list of Tensor
        The update rule of the scan given by symbolic tensor.

    state_placeholder: Tensor or list of Tensor
        The placeholder variables used by update.

    inputs: Tensor or list of Tensor, optional
        The list of inputs to the scan. This is not required, but can
        help the compiler detect the scan body faster.

    name: str, optional
        The name hint of the tensor.

    tag: str, optional
        Additional tag information about the compute.

    attrs: dict, optional
        The additional auxiliary attributes about the compute.

    Returns
    -------
    tensor: Tensor or list of Tensors
        The created tensor or tuple of tensors if it contains multiple outputs.

    Example
    -------
    .. code-block:: python

      # The following code is equivalent to numpy.cumsum
      m = tvm.var("m")
      n = tvm.var("n")
      X = tvm.placeholder((m, n), name="X")
      s_state = tvm.placeholder((m, n))
      s_init = tvm.compute((1, n), lambda _, i: X[0, i])
      s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
      res = tvm.scan(s_init, s_update, s_state, X)
    """
    if _tag.TagScope.get_current() is not None:
        if tag != "":
            raise ValueError("nested tag is not allowed for now")
        tag = _tag.TagScope.get_current().tag
    if isinstance(init, _tensor.Tensor):
        init = [init]
    if isinstance(update, _tensor.Tensor):
        update = [update]
    if isinstance(state_placeholder, _tensor.Tensor):
        state_placeholder = [state_placeholder]
    if isinstance(inputs, _tensor.Tensor):
        inputs = [inputs]
    if inputs is None:
        inputs = []
    if len(init) != len(update) or len(init) != len(state_placeholder):
        raise ValueError("init, update, state_placeholder must have same length")
    axis = _IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name, 3)
    op = _api_internal._ScanOp(name, tag, attrs,
                               axis, init, update,
                               state_placeholder, inputs)
    res = [op.output(i) for i in range(len(update))]
    return res[0] if len(res) == 1 else res


def extern(shape,
           inputs,
           fcompute,
           name="extern",
           dtype=None,
           in_buffers=None,
           out_buffers=None,
           tag="",
           attrs=None):
    """Compute several tensors via an extern function.

    Parameters
    ----------
    shape: tuple or list of tuples.
        The shape of the outputs.

    inputs: list of Tensor
        The inputs.

    fcompute: lambda function of inputs, outputs -> stmt
        Specifies the IR statement to do the computation.
        See the following note for the function signature of fcompute.

        .. note::
             **Parameters**

             - **ins** (list of :any:`Buffer`) - Placeholder for each input
             - **outs** (list of :any:`Buffer`) - Placeholder for each output

             **Returns**

             - **stmt** (:any:`Stmt`) - The statement that carries out array computation.

    name: str, optional
        The name hint of the tensor.

    dtype: str or list of str, optional
        The data types of the outputs;
        by default dtype will be the same as inputs.

    in_buffers: Buffer or list of Buffer, optional
        Input buffers.

    out_buffers: Buffer or list of Buffer, optional
        Output buffers.

    tag: str, optional
        Additional tag information about the compute.

    attrs: dict, optional
        The additional auxiliary attributes about the compute.

    Returns
    -------
    tensor: Tensor or list of Tensors
        The created tensor or tuple of tensors if it contains multiple outputs.

    Example
    -------
    In the code below, C is generated by calling the external PackedFunc
    `tvm.contrib.cblas.matmul`

    .. code-block:: python

        A = tvm.placeholder((n, l), name='A')
        B = tvm.placeholder((l, m), name='B')
        C = tvm.extern((n, m), [A, B],
                       lambda ins, outs: tvm.call_packed(
                           "tvm.contrib.cblas.matmul",
                           ins[0], ins[1], outs[0], 0, 0), name="C")
    """
    if _tag.TagScope.get_current() is not None:
        if tag != "":
            raise ValueError("nested tag is not allowed for now")
        tag = _tag.TagScope.get_current().tag
    shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
    if shape == () or isinstance(shape[0], (_expr.Expr, _Integral)):
        shape = [shape]
    if in_buffers is not None:
        in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers
        if len(inputs) != len(in_buffers):
            raise RuntimeError("Number of inputs and in_buffers mismatch: %d vs %d."
                               % (len(inputs), len(in_buffers)))
    if out_buffers is not None:
        out_buffers = [out_buffers] if not isinstance(out_buffers, list) else out_buffers
        if len(shape) != len(out_buffers):
            raise RuntimeError("Number of outputs and out_buffers mismatch: %d vs %d."
                               % (len(shape), len(out_buffers)))
    input_placeholders = in_buffers or []
    output_placeholders = out_buffers or []
    types = set()
    for t in inputs:
        if not isinstance(t, _tensor.Tensor):
            raise ValueError("expect inputs to be tensor")
        if in_buffers is None:
            input_placeholders.append(
                decl_buffer(t.shape, t.dtype, t.op.name))
        types.add(t.dtype)

    if dtype is None:
        if len(types) != 1:
            raise ValueError("Cannot infer output type, please provide dtype argument")
        inferred_type = types.pop()
        dtype = [inferred_type for _ in shape]
    if isinstance(dtype, str):
        dtype = [dtype]

    if out_buffers is None:
        for shp, dt in zip(shape, dtype):
            output_placeholders.append(decl_buffer(shp, dt, name))
    body = fcompute(input_placeholders, output_placeholders)
    if isinstance(body, _expr.Expr):
        body = _make.Evaluate(body)

    op = _api_internal._ExternOp(name, tag, attrs,
                                 inputs, input_placeholders,
                                 output_placeholders, body)
    res = [op.output(i) for i in range(len(output_placeholders))]
    return res[0] if len(res) == 1 else res


def decl_buffer(shape,
                dtype=None,
                name="buffer",
                data=None,
                strides=None,
                elem_offset=None,
                scope="",
                data_alignment=-1,
                offset_factor=0,
                buffer_type=""):
    """Declare a new symbolic buffer.

    Normally buffers are created automatically during lower and build.
    This is only needed if users want to specify their own buffer layout.

    See the note below for detailed discussion on usage of buffer.

    Parameters
    ----------
    shape : tuple of Expr
        The shape of the buffer.

    dtype : str, optional
        The data type of the buffer.

    name : str, optional
        The name of the buffer.

    data : Var, optional
        The data pointer in the buffer.

    strides: array of Expr
        The stride of the buffer.

    elem_offset: Expr, optional
        The beginning offset of the array to data,
        in terms of number of elements of dtype.

    scope: str, optional
        The storage scope of the buffer, if not global.
        If scope equals empty string, it means it is global memory.

    data_alignment: int, optional
        The alignment of data pointer in bytes.
        If -1 is passed, the alignment will be set to TVM's internal default.

    offset_factor: int, optional
        The factor of the elem_offset field; when set,
        elem_offset is required to be a multiple of offset_factor.
        If 0 is passed, the alignment will be set to 1.
        If non-zero is passed, a Var will be created for elem_offset
        when elem_offset is None.

    buffer_type: str, optional, {"", "auto_broadcast"}
        An auto_broadcast buffer allows one to implement broadcast computation
        without considering whether the dimension size equals one.
        TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension j's shape equals 1.

    Returns
    -------
    buffer : Buffer
        The created buffer

    Example
    -------
    Here's an example of how a broadcast buffer can be used to define a symbolic
    broadcast operation,

    .. code-block:: python

        m0, m1, m2 = tvm.var("m0"), tvm.var("m1"), tvm.var("m2")
        n0, n1, n2 = tvm.var("n0"), tvm.var("n1"), tvm.var("n2")
        o0, o1, o2 = tvm.var("o0"), tvm.var("o1"), tvm.var("o2")
        A = tvm.placeholder((m0, m1, m2), name='A')
        B = tvm.placeholder((n0, n1, n2), name='B')
        C = tvm.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')
        Ab = tvm.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
        Bb = tvm.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
        s = tvm.create_schedule(C.op)
        fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb})
        ctx = tvm.cpu(0)
        a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), ctx)
        c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx)
        fadd(a, b, c)
        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())

    Note
    ----
    The Buffer data structure reflects the DLTensor structure in dlpack.
    While the DLTensor data structure is very general, it is usually helpful
    to create a function that only handles a specific case of the data
    structure, so that the compiled function can benefit from it.

    If the user passes strides and elem_offset as None
    when constructing the function, then the function will be specialized
    for a DLTensor that is compact and aligned.
    If the user passes a fully generic symbolic array for the strides,
    then the resulting function becomes fully generic.
    """
    shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
    dtype = float32 if dtype is None else dtype
    strides = () if strides is None else strides
    if offset_factor != 0 and elem_offset is None:
        shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
        elem_offset = var('%s_elem_offset' % name, shape_dtype)
    if data is None:
        data = var(name, "handle")
    return _api_internal._Buffer(
        data, dtype, shape, strides, elem_offset, name, scope,
        data_alignment, offset_factor, buffer_type)

def layout(layout_str):
    """Create a layout node from a string.

    Parameters
    ----------
    layout_str : str
        A layout representation is composed of upper cases, lower cases and numbers,
        where an upper case indicates a primal axis and
        the corresponding lower case with factor size indicates the subordinate axis.
        For example, NCHW16c can describe a 5-D tensor of
        [batch_size, channel, height, width, channel_block].
        Here the subordinate axis channel_block=16 is the factor size of
        the primal axis C (channel).

    Returns
    -------
    layout : Layout
        The created layout
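
    Example
    -------
    A minimal sketch:

    .. code-block:: python

        l = tvm.layout("NCHW16c")  # N, C, H, W primal axes plus a 16-wide c sub-axis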
    """
    return _api_internal._Layout(layout_str)

def bijective_layout(src_layout, dst_layout):
    """Create a bijective layout mapping.

    Parameters
    ----------
    src_layout : str or Layout
        Source layout.

    dst_layout : str or Layout
        Destination layout.

    Returns
    -------
    bijective_layout : BijectiveLayout
        The created bijective layout
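
    Example
    -------
    A minimal sketch; the index-mapping call assumes BijectiveLayout's
    ``forward_index`` method:

    .. code-block:: python

        bl = tvm.bijective_layout("NCHW", "NCHW16c")
        idx = bl.forward_index([0, 18, 6, 6])  # the same element in NCHW16c coordinates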
    """
    if isinstance(src_layout, str):
        src_layout = layout(src_layout)
    if isinstance(dst_layout, str):
        dst_layout = layout(dst_layout)
    return _api_internal._BijectiveLayout(src_layout, dst_layout)

def _IterVar(dom, name, iter_type, thread_tag=''):
    """Internal function to create IterVar

    Parameters
    ----------
    dom : Range
        The domain of iteration.

    name : str
        The name of the iteration variable.

    iter_type : int
        The type of iteration.

    thread_tag : str
        The thread tag of the iteration variable.

    Returns
    -------
    iter_var : IterVar
       The result itervar
    """
    if dom is not None:
        if isinstance(dom, (list, tuple)):
            if len(dom) != 2:
                raise TypeError("need to be list of ranges")
            dom = Range(dom[0], dom[1])

        if not isinstance(dom, _container.Range):
            raise TypeError("dom need to be Range")
    name = name if name else 'iter'
    v = var(name)
    return _api_internal._IterVar(dom, v, iter_type, thread_tag)


def thread_axis(dom=None, tag='', name=''):
    """Create a new IterVar to represent a thread index.

    Parameters
    ----------
    dom : Range or str
        The domain of iteration.
        When a str is passed, dom is set to None and the str is used as the tag.

    tag : str, optional
        The thread tag.

    name : str, optional
        The name of the var.

    Returns
    -------
    axis : IterVar
        The thread itervar.
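
    Example
    -------
    A minimal sketch; ``s`` and ``C`` stand for a pre-built schedule and its
    output tensor:

    .. code-block:: python

        bx = tvm.thread_axis("blockIdx.x")
        tx = tvm.thread_axis("threadIdx.x")
        xo, xi = s[C].split(C.op.axis[0], factor=64)
        s[C].bind(xo, bx)
        s[C].bind(xi, tx)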
    """
    if isinstance(dom, string_types):
        tag, dom = dom, None
    if not tag:
        raise ValueError("tag must be given as positional or keyword argument")
    name = name if name else tag
    return _IterVar(dom, name, 1, tag)


def reduce_axis(dom, name="rv"):
    """Create a new IterVar for reduction.

    Parameters
    ----------
    dom : Range
        The domain of iteration.

    name : str
        The name of the variable.

    Returns
    -------
    axis : IterVar
        An iteration variable representing the value.
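
    Example
    -------
    A minimal sketch summing over the second axis of a placeholder:

    .. code-block:: python

        m = tvm.var("m")
        n = tvm.var("n")
        A = tvm.placeholder((m, n), name="A")
        k = tvm.reduce_axis((0, n), name="k")
        B = tvm.compute((m,), lambda i: tvm.sum(A[i, k], axis=k), name="B")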
    """
    return _IterVar(dom, name, 2)


def comm_reducer(fcombine, fidentity, name="reduce"):
    """Create a commutative reducer for reduction.

    Parameters
    ----------
    fcombine : function(Expr -> Expr -> Expr)
        A binary function which takes two Expr as input to return an Expr.

    fidentity : function(str -> Expr)
        A function which takes a type string as input to return a const Expr.

    Returns
    -------
    reducer : function
        A function which creates a reduce expression over axis.
        There are two ways to use it:

        1. accept (expr, axis, where) to produce a Reduce Expr on
           the specified axis;
        2. simply use it with multiple Exprs.

    Example
    -------
    .. code-block:: python

        n = tvm.var('n')
        m = tvm.var('m')
        mysum = tvm.comm_reducer(lambda x, y: x+y,
            lambda t: tvm.const(0, dtype=t), name="mysum")
        A = tvm.placeholder((n, m), name='A')
        k = tvm.reduce_axis((0, m), name='k')
        B = tvm.compute((n,), lambda i: mysum(A[i, k], axis=k), name='B')
    """
    def _reduce_directly(*args):
        num = len(args)
        # handle the case when `where` is None
        if num == 3 and args[2] is None:
            num = 2
        res = args[0]
        for i in range(num-1):
            res = fcombine(res, args[i+1])
        return res

    def _make_reduce(expr, axis, where=None):
        code = fcombine.__code__
        assert code.co_argcount == 2
        expr = convert(expr)
        if isinstance(expr, _container.Array):
            size = len(expr)
            larr = []
            rarr = []
            dtypes = []
            for i in range(size):
                dtype = expr[i].dtype
                dtypes.append(dtype)
                lname = code.co_varnames[0] + '_' + str(i)
                larr.append(var(lname, dtype))
                rname = code.co_varnames[1] + '_' + str(i)
                rarr.append(var(rname, dtype))
            lhs = convert(larr)
            rhs = convert(rarr)
            result = fcombine(lhs, rhs)
            id_elem = fidentity(*dtypes)
        else:
            assert isinstance(expr, _expr.Expr)
            size = 1
            dtype = expr.dtype
            lvar = var(code.co_varnames[0], dtype)
            rvar = var(code.co_varnames[1], dtype)
            result = [fcombine(lvar, rvar)]
            id_elem = [fidentity(dtype)]
            lhs = convert([lvar])
            rhs = convert([rvar])
            expr = convert([expr])
        result = convert(result)
        id_elem = convert(id_elem)
        combiner = _make.CommReducer(lhs, rhs, result, id_elem)
        axis = convert(axis if isinstance(axis, (list, tuple)) else [axis])
        if where is None:
            where = convert(True)
        outputs = tuple(_expr.Reduce(combiner, expr, axis, where, i)
                        for i in range(size))
        return outputs[0] if size == 1 else outputs

    # pylint: disable=keyword-arg-before-vararg
    def reducer(expr, axis, where=None, *args):
        if isinstance(axis, (_schedule.IterVar, list, tuple)):
            assert not args
            return _make_reduce(expr, axis, where)
        if where is None:
            assert not args
            return _reduce_directly(expr, axis)
        return _reduce_directly(expr, axis, where, *args)

    doc_str = """Create a {0} expression over axis.

              Parameters
              ----------
              expr : Expr
                  The source expression.
              axis : IterVar
                  The reduction IterVar axis
              where : optional, Expr
                  Filtering predicate of the reduction.
              Returns
              -------
              value : Expr
                  The result value.

              Example
              -------
              .. code-block:: python

                m = tvm.var("m")
                n = tvm.var("n")
                A = tvm.placeholder((m, n), name="A")
                k = tvm.reduce_axis((0, n), name="k")

                # there are two ways to use this {0} reducer:
                # mode 1, accept (expr, axis, where) to produce a Reduce Expr
                B = tvm.compute((m,), lambda i: tvm.{0}(A[i, k], axis=k), name="B")

                # mode 2, simply use it with multiple Exprs:
                {0}_res = tvm.{0}(m, n)
              """
    reducer.__doc__ = doc_str.format(name)
    return reducer

def div(a, b):
    """Compute a / b as in C/C++ semantics.

    Parameters
    ----------
    a : Expr
        The left hand operand, known to be non-negative.

    b : Expr
        The right hand operand, known to be non-negative.

    Returns
    -------
    res : Expr
        The result expression.

    Note
    ----
    When operands are integers, returns truncdiv(a, b).
    """
    return _make._OpDiv(a, b)


def indexdiv(a, b):
    """Compute floor(a / b) where a and b are non-negative.

    Parameters
    ----------
    a : Expr
        The left hand operand, known to be non-negative.

    b : Expr
        The right hand operand, known to be non-negative.

    Returns
    -------
    res : Expr
        The result expression.

    Note
    ----
    Use this function to split non-negative indices.
    This function may take advantage of operands'
    non-negativeness.
    """
    return _make._OpIndexDiv(a, b)


def indexmod(a, b):
    """Compute the remainder of indexdiv. a and b are non-negative.

    Parameters
    ----------
    a : Expr
        The left hand operand, known to be non-negative.

    b : Expr
        The right hand operand, known to be non-negative.

    Returns
    -------
    res : Expr
        The result expression.

    Note
    ----
    Use this function to split non-negative indices.
    This function may take advantage of operands'
    non-negativeness.
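
    Example
    -------
    A minimal sketch splitting a flat index with :any:`indexdiv`:

    .. code-block:: python

        idx = tvm.var("idx")
        i = tvm.indexdiv(idx, 16)  # outer index
        j = tvm.indexmod(idx, 16)  # inner index, 0 <= j < 16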
    """
    return _make._OpIndexMod(a, b)


def truncdiv(a, b):
    """Compute the truncdiv of two expressions.

    Parameters
    ----------
    a : Expr
        The left hand operand

    b : Expr
        The right hand operand

    Returns
    -------
    res : Expr
        The result expression.

    Note
    ----
    This is the default integer division behavior in C.
    """
    return _make._OpTruncDiv(a, b)


def truncmod(a, b):
    """Compute the truncmod of two expressions.

    Parameters
    ----------
    a : Expr
        The left hand operand

    b : Expr
        The right hand operand

    Returns
    -------
    res : Expr
        The result expression.

    Note
    ----
    This is the remainder of the default integer division in C.
    """
    return _make._OpTruncMod(a, b)


def floordiv(a, b):
    """Compute the floordiv of two expressions.

    Parameters
    ----------
    a : Expr
        The left hand operand

    b : Expr
        The right hand operand

    Returns
    -------
    res : Expr
        The result expression.
    """
    return _make._OpFloorDiv(a, b)


def floormod(a, b):
    """Compute the floormod of two expressions.

    Parameters
    ----------
    a : Expr
        The left hand operand

    b : Expr
        The right hand operand

    Returns
    -------
    res : Expr
        The result expression.
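
    Example
    -------
    A minimal sketch contrasting floor and truncated semantics on a
    negative operand:

    .. code-block:: python

        a = tvm.const(-7, "int32")
        b = tvm.const(4, "int32")
        tvm.truncdiv(a, b)  # -1 (rounds toward zero)
        tvm.floordiv(a, b)  # -2 (rounds toward negative infinity)
        tvm.truncmod(a, b)  # -3
        tvm.floormod(a, b)  #  1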
    """
    return _make._OpFloorMod(a, b)


_init_api("tvm.api")

# pylint: disable=unnecessary-lambda
sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum")
min = comm_reducer(lambda x, y: _make._OpMin(x, y), max_value, name='min')
max = comm_reducer(lambda x, y: _make._OpMax(x, y), min_value, name='max')