1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""Abstraction for array data structures."""
18from numbers import Integral
19import tvm._ffi
20
21from tvm._ffi.base import string_types
22from tvm.runtime import Object, convert
23from tvm.ir import PrimExpr, PointerType, PrimType
24from . import _ffi_api
25
26
27@tvm._ffi.register_object("tir.Buffer")
28class Buffer(Object):
29    """Symbolic data buffer in TVM.
30
31    Buffer provide a way to represent data layout
32    specialization of data structure in TVM.
33
34    Do not construct directly, use :py:func:`~decl_buffer` instead.
35    See the documentation of :py:func:`decl_buffer` for more details.
36
37    See Also
38    --------
39    decl_buffer : Declare a buffer
40    """
41
42    READ = 1
43    WRITE = 2
44
45    def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
46        """Get an access pointer to the head of buffer.
47
48        This is the recommended method to get buffer data
49        ptress when interacting with external functions.
50
51        Parameters
52        ----------
53        access_mask : int
54            The access pattern MASK. Indicate whether the
55            access will read or write to the data content.
56
57        ptr_type : str, optional
58            The data type of the result pointer. Do not specify
59            unless we want to cast pointer to specific type.
60
61        content_lanes: int, optional
62            The number of lanes for the data type. This value
63            is greater than one for vector types.
64
65        offset: Expr, optional
66            The offset of pointer. We can use it to offset by
67            the number of elements from the address of ptr.
68
69        Examples
70        --------
71        .. code-block:: python
72
73          # Get access ptr for read
74          buffer.access_ptr("r")
75          # Get access ptr for read/write with bitmask
76          buffer.access_ptr(Buffer.READ | Buffer.WRITE)
77          # Get access ptr for read/write with str flag
78          buffer.access_ptr("rw")
79          # Get access ptr for read with offset
80          buffer.access_ptr("r", offset = 100)
81        """
82        if isinstance(access_mask, string_types):
83            mask = 0
84            for value in access_mask:
85                if value == "r":
86                    mask = mask | Buffer.READ
87                elif value == "w":
88                    mask = mask | Buffer.WRITE
89                else:
90                    raise ValueError("Unknown access_mask %s" % access_mask)
91            access_mask = mask
92        offset = convert(offset)
93        return _ffi_api.BufferAccessPtr(self, access_mask, ptr_type, content_lanes, offset)
94
95    def vload(self, begin, dtype=None):
96        """Generate an Expr that loads dtype from begin index.
97
98        Parameters
99        ----------
100        begin : Array of Expr
101            The beginning index in unit of Buffer.dtype
102
103        dtype : str
104            The data type to be loaded,
105            can be vector type which have lanes that is multiple of Buffer.dtype
106
107        Returns
108        -------
109        load : Expr
110            The corresponding load expression.
111        """
112        begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin
113        dtype = dtype if dtype else self.dtype
114        return _ffi_api.BufferVLoad(self, begin, dtype)
115
116    def vstore(self, begin, value):
117        """Generate a Stmt that store value into begin index.
118
119        Parameters
120        ----------
121        begin : Array of Expr
122            The beginning index in unit of Buffer.dtype
123
124        value : Expr
125            The value to be stored.
126
127        Returns
128        -------
129        store : Stmt
130            The corresponding store stmt.
131        """
132        begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin
133        return _ffi_api.BufferVStore(self, begin, value)
134
135
136def decl_buffer(
137    shape,
138    dtype=None,
139    name="buffer",
140    data=None,
141    strides=None,
142    elem_offset=None,
143    scope="",
144    data_alignment=-1,
145    offset_factor=0,
146    buffer_type="",
147):
148    """Declare a new symbolic buffer.
149
150    Normally buffer is created automatically during lower and build.
151    This is only needed if user want to specify their own buffer layout.
152
153    See the note below for detailed discussion on usage of buffer.
154
155    Parameters
156    ----------
157    shape : tuple of Expr
158        The shape of the buffer.
159
160    dtype : str, optional
161        The data type of the buffer.
162
163    name : str, optional
164        The name of the buffer.
165
166    data : Var, optional
167        The data pointer in the buffer.
168
169    strides: array of Expr
170        The stride of the buffer.
171
172    elem_offset: Expr, optional
173        The beginning offset of the array to data.
174        In terms of number of elements of dtype.
175
176    scope: str, optional
177        The storage scope of the buffer, if not global.
178        If scope equals empty string, it means it is global memory.
179
180    data_alignment: int, optional
181        The alignment of data pointer in bytes.
182        If -1 is passed, the alignment will be set to TVM's internal default.
183
184    offset_factor: int, optional
185        The factor of elem_offset field, when set,
186        elem_offset is required to be multiple of offset_factor.
187        If 0 is pssed, the alignment will be set to 1.
188        if non-zero is passed, we will created a Var for elem_offset if elem_offset is not None.
189
190    buffer_type: str, optional, {"", "auto_broadcast"}
191        auto_broadcast buffer allows one to implement broadcast computation
192        without considering whether dimension size equals to one.
193        TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension j's shape equals 1.
194
195    Returns
196    -------
197    buffer : Buffer
198        The created buffer
199
200    Example
201    -------
202    Here's an example of how broadcast buffer can be used to define a symbolic broadcast operation,
203
204    .. code-block:: python
205
206        m0, m1, m2 = te.var("m0"), te.var("m1"), te.var("m2")
207        n0, n1, n2 = te.var("n0"), te.var("n1"), te.var("n2")
208        o0, o1, o2 = te.var("o0"), te.var("o1"), te.var("o2")
209        A = te.placeholder((m0, m1, m2), name='A')
210        B = te.placeholder((n0, n1, n2), name='B')
211        C = te.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')
212        Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
213        Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
214        s = te.create_schedule(C.op)
215        fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb})
216        ctx = tvm.cpu(0)
217        a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx)
218        b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), ctx)
219        c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx)
220        fadd(a, b, c)
221        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
222
223    Note
224    ----
225    Buffer data structure reflects the DLTensor structure in dlpack.
226    While DLTensor data structure is very general, it is usually helpful
227    to create function that only handles specific case of data structure
228    and make compiled function benefit from it.
229
230    If user pass strides and elem_offset is passed as None
231    when constructing the function, then the function will be specialized
232    for the DLTensor that is compact and aligned.
233    If user pass a fully generic symbolic array to the strides,
234    then the resulting function becomes fully generic.
235    """
236    # pylint: disable=import-outside-toplevel
237    from .expr import Var
238
239    shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
240    dtype = "float32" if dtype is None else dtype
241    strides = () if strides is None else strides
242    if offset_factor != 0 and elem_offset is None:
243        shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
244        elem_offset = Var("%s_elem_offset" % name, shape_dtype)
245    if data is None:
246        data = Var(name, PointerType(PrimType(dtype)))
247    return _ffi_api.Buffer(
248        data,
249        dtype,
250        shape,
251        strides,
252        elem_offset,
253        name,
254        scope,
255        data_alignment,
256        offset_factor,
257        buffer_type,
258    )
259
260
261@tvm._ffi.register_object("tir.DataProducer")
262class DataProducer(Object):
263    pass
264