1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""Tensor intrinsics"""
18from __future__ import absolute_import as _abs
19from . import _api_internal
20from . import api as _api
21from . import expr as _expr
22from . import stmt as _stmt
23from . import make as _make
24from . import tensor as _tensor
25from . import schedule as _schedule
26from .build_module import current_build_config
27from ._ffi.node import NodeBase, register_node
28
29
30def _get_region(tslice):
31    region = []
32    for idx in tslice.indices:
33        if isinstance(idx, slice):
34            assert idx.step is None
35            region.append(_api.Range(idx.start, idx.stop))
36        else:
37            if isinstance(idx, _schedule.IterVar):
38                begin = idx.var
39            else:
40                begin = idx
41            region.append(_make.range_by_min_extent(begin, 1))
42    return region
43
@register_node
class TensorIntrin(NodeBase):
    """Tensor intrinsic functions for certain computation.

    See Also
    --------
    decl_tensor_intrin: Construct a TensorIntrin
    """
    def __call__(self, *args, **kwargs):
        """Invoke the intrinsic on tensor slices and scalar values.

        Parameters
        ----------
        *args
            Positional arguments that are ``TensorSlice`` instances supply
            a tensor and a region each; every other positional argument is
            forwarded as a scalar input.

        **kwargs
            Only ``reduce_axis`` is read. It may be a single axis or a
            list/tuple of axes. Any other keyword is ignored.

        Returns
        -------
        The value returned by ``_api_internal._TensorIntrinCall``
        (presumably a tensor intrinsic call node -- confirm against the
        backend).
        """
        slices = [arg for arg in args if isinstance(arg, _tensor.TensorSlice)]
        scalars = [arg for arg in args if not isinstance(arg, _tensor.TensorSlice)]
        tensors = [item.tensor for item in slices]
        regions = [_get_region(item) for item in slices]

        if "reduce_axis" in kwargs:
            axes = kwargs["reduce_axis"]
            # Accept a lone axis as well as a sequence of axes.
            if not isinstance(axes, (list, tuple)):
                axes = [axes]
            reduce_axis = _api.convert(axes)
        else:
            reduce_axis = []

        # An empty scalar list is passed through unconverted.
        if scalars:
            scalars = _api.convert(scalars)
        return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis, scalars)
65
def decl_tensor_intrin(op,
                       fcompute,
                       name="tensor_intrin",
                       binds=None, scalar_params=None):
    """Declare a tensor intrinsic function.

    Parameters
    ----------
    op: Operation
        The symbolic description of the intrinsic operation.
        All of its input tensors must come from placeholder ops.

    fcompute: lambda function of inputs, outputs-> stmt
        Specifies the IR statement to do the computation.
        See the following note for function signature of fcompute

        .. note::
             **Parameters**

             - **ins** (list of :any:`Buffer`) - Placeholder for each input
             - **outs** (list of :any:`Buffer`) - Placeholder for each output
             - **scalar_params** - Passed as a third argument only when a
               non-empty ``scalar_params`` is given to this function

             **Returns**

             - **stmt** (:any:`Stmt`, or tuple of three stmts)
             - If a single stmt is returned, it represents the body
             - If a tuple of three stmts is returned, they correspond to
               body, reduce_init, reduce_update
             - Any returned :any:`Expr` is wrapped in an Evaluate statement

    name: str, optional
        The name of the intrinsic.

    binds: dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps a Tensor to the Buffer which specifies the data
        layout requirement of the function. For every input or output tensor
        without an entry, a new buffer is declared using the data alignment
        and offset factor of the current build config.

    scalar_params: list, optional
        Variables used by op, whose values will be passed as scalar_inputs
        when the tensor intrinsic is called.

    Returns
    -------
    intrin: TensorIntrin
        A TensorIntrin that can be used in tensorize schedule.

    Raises
    ------
    TypeError
        If ``op`` is not an Operation.
    ValueError
        If any input tensor of ``op`` is not produced by a placeholder op.
    """
    if not isinstance(op, _tensor.Operation):
        raise TypeError("expect Operation")

    inputs = op.input_tensors
    num_inputs = len(inputs)
    bound = binds or {}

    # Inputs first, then every output of the op, in index order.
    all_tensors = list(inputs)
    all_tensors.extend(op.output(k) for k in range(op.num_outputs))

    if any(not isinstance(tensor.op, _tensor.PlaceholderOp) for tensor in inputs):
        raise ValueError("Do not yet support composition op")

    build_cfg = current_build_config()
    buffers = []
    for tensor in all_tensors:
        if tensor in bound:
            buffers.append(bound[tensor])
        else:
            buffers.append(
                _api.decl_buffer(tensor.shape, tensor.dtype, tensor.op.name,
                                 data_alignment=build_cfg.data_alignment,
                                 offset_factor=build_cfg.offset_factor))

    # The buffer list mirrors all_tensors: inputs occupy the leading slots.
    if not scalar_params:
        result = fcompute(buffers[:num_inputs], buffers[num_inputs:])
        scalar_params = []
    else:
        result = fcompute(buffers[:num_inputs], buffers[num_inputs:], scalar_params)

    # Normalise to a list of statements; a lone Expr/Stmt is the body.
    if isinstance(result, (_expr.Expr, _stmt.Stmt)):
        result = [result]
    stmts = []
    for item in result:
        stmts.append(_make.Evaluate(item) if isinstance(item, _expr.Expr) else item)

    # Pad with None so that three statement slots are always supplied.
    while len(stmts) < 3:
        stmts.append(None)

    return _api_internal._TensorIntrin(
        name, op, inputs, buffers, scalar_params, *stmts)
143