# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Abstraction for array data structures."""
from numbers import Integral

import tvm._ffi

from tvm._ffi.base import string_types
from tvm.runtime import Object, convert
from tvm.ir import PrimExpr, PointerType, PrimType
from . import _ffi_api


@tvm._ffi.register_object("tir.Buffer")
class Buffer(Object):
    """Symbolic data buffer in TVM.

    Buffer provides a way to represent data layout
    specialization of data structure in TVM.

    Do not construct directly, use :py:func:`~decl_buffer` instead.
    See the documentation of :py:func:`decl_buffer` for more details.

    See Also
    --------
    decl_buffer : Declare a buffer
    """

    # Bitmask flags for access_ptr; combine with `|` for read/write access.
    READ = 1
    WRITE = 2

    def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
        """Get an access pointer to the head of buffer.

        This is the recommended method to get buffer data
        address when interacting with external functions.

        Parameters
        ----------
        access_mask : int
            The access pattern MASK. Indicate whether the
            access will read or write to the data content.

        ptr_type : str, optional
            The data type of the result pointer. Do not specify
            unless we want to cast pointer to specific type.

        content_lanes: int, optional
            The number of lanes for the data type. This value
            is greater than one for vector types.

        offset: Expr, optional
            The offset of pointer. We can use it to offset by
            the number of elements from the address of ptr.

        Examples
        --------
        .. code-block:: python

            # Get access ptr for read
            buffer.access_ptr("r")
            # Get access ptr for read/write with bitmask
            buffer.access_ptr(Buffer.READ | Buffer.WRITE)
            # Get access ptr for read/write with str flag
            buffer.access_ptr("rw")
            # Get access ptr for read with offset
            buffer.access_ptr("r", offset = 100)
        """
        # Accept a string mask like "rw" and translate it into the
        # READ/WRITE bit flags expected by the C++ side.
        if isinstance(access_mask, string_types):
            mask = 0
            for value in access_mask:
                if value == "r":
                    mask = mask | Buffer.READ
                elif value == "w":
                    mask = mask | Buffer.WRITE
                else:
                    raise ValueError("Unknown access_mask %s" % access_mask)
            access_mask = mask
        offset = convert(offset)
        return _ffi_api.BufferAccessPtr(self, access_mask, ptr_type, content_lanes, offset)

    def vload(self, begin, dtype=None):
        """Generate an Expr that loads dtype from begin index.

        Parameters
        ----------
        begin : Array of Expr
            The beginning index in unit of Buffer.dtype

        dtype : str
            The data type to be loaded,
            can be vector type which have lanes that is multiple of Buffer.dtype

        Returns
        -------
        load : Expr
            The corresponding load expression.
        """
        # Normalize a scalar index into a 1-tuple so the FFI always
        # receives an array of indices.
        begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin
        dtype = dtype if dtype else self.dtype
        return _ffi_api.BufferVLoad(self, begin, dtype)

    def vstore(self, begin, value):
        """Generate a Stmt that store value into begin index.

        Parameters
        ----------
        begin : Array of Expr
            The beginning index in unit of Buffer.dtype

        value : Expr
            The value to be stored.

        Returns
        -------
        store : Stmt
            The corresponding store stmt.
        """
        # Normalize a scalar index into a 1-tuple, mirroring vload.
        begin = (begin,) if isinstance(begin, (int, PrimExpr)) else begin
        return _ffi_api.BufferVStore(self, begin, value)


def decl_buffer(
    shape,
    dtype=None,
    name="buffer",
    data=None,
    strides=None,
    elem_offset=None,
    scope="",
    data_alignment=-1,
    offset_factor=0,
    buffer_type="",
):
    """Declare a new symbolic buffer.

    Normally buffer is created automatically during lower and build.
    This is only needed if user want to specify their own buffer layout.

    See the note below for detailed discussion on usage of buffer.

    Parameters
    ----------
    shape : tuple of Expr
        The shape of the buffer.

    dtype : str, optional
        The data type of the buffer.

    name : str, optional
        The name of the buffer.

    data : Var, optional
        The data pointer in the buffer.

    strides: array of Expr
        The stride of the buffer.

    elem_offset: Expr, optional
        The beginning offset of the array to data.
        In terms of number of elements of dtype.

    scope: str, optional
        The storage scope of the buffer, if not global.
        If scope equals empty string, it means it is global memory.

    data_alignment: int, optional
        The alignment of data pointer in bytes.
        If -1 is passed, the alignment will be set to TVM's internal default.

    offset_factor: int, optional
        The factor of elem_offset field, when set,
        elem_offset is required to be multiple of offset_factor.
        If 0 is passed, the alignment will be set to 1.
        If non-zero is passed, we will create a Var for elem_offset if elem_offset is not None.

    buffer_type: str, optional, {"", "auto_broadcast"}
        auto_broadcast buffer allows one to implement broadcast computation
        without considering whether dimension size equals to one.
        TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension j's shape equals 1.

    Returns
    -------
    buffer : Buffer
        The created buffer

    Example
    -------
    Here's an example of how broadcast buffer can be used to define a symbolic broadcast operation,

    .. code-block:: python

        m0, m1, m2 = te.var("m0"), te.var("m1"), te.var("m2")
        n0, n1, n2 = te.var("n0"), te.var("n1"), te.var("n2")
        o0, o1, o2 = te.var("o0"), te.var("o1"), te.var("o2")
        A = te.placeholder((m0, m1, m2), name='A')
        B = te.placeholder((n0, n1, n2), name='B')
        C = te.compute((o0, o1, o2), lambda i, j, k: A[i, j, k] + B[i, j, k], name='C')
        Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name="Ab", buffer_type="auto_broadcast")
        Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name="Bb", buffer_type="auto_broadcast")
        s = te.create_schedule(C.op)
        fadd = tvm.build(s, [A, B, C], target='llvm', name='bcast_add', binds={A:Ab, B:Bb})
        ctx = tvm.cpu(0)
        a = tvm.nd.array(np.random.uniform(size=(2, 4, 3)).astype(A.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=(2, 1, 3)).astype(B.dtype), ctx)
        c = tvm.nd.array(np.zeros((2, 4, 3), dtype=C.dtype), ctx)
        fadd(a, b, c)
        tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())

    Note
    ----
    Buffer data structure reflects the DLTensor structure in dlpack.
    While DLTensor data structure is very general, it is usually helpful
    to create function that only handles specific case of data structure
    and make compiled function benefit from it.

    If user passes strides and elem_offset is passed as None
    when constructing the function, then the function will be specialized
    for the DLTensor that is compact and aligned.
    If user passes a fully generic symbolic array to the strides,
    then the resulting function becomes fully generic.
    """
    # pylint: disable=import-outside-toplevel
    from .expr import Var

    # Normalize scalar shape to a 1-tuple; apply defaults for dtype/strides.
    shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
    dtype = "float32" if dtype is None else dtype
    strides = () if strides is None else strides
    if offset_factor != 0 and elem_offset is None:
        # Create a symbolic elem_offset whose dtype follows the shape's
        # index dtype (falls back to int32 for plain Python ints).
        shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
        elem_offset = Var("%s_elem_offset" % name, shape_dtype)
    if data is None:
        # Default data pointer: a handle typed as pointer-to-dtype.
        data = Var(name, PointerType(PrimType(dtype)))
    return _ffi_api.Buffer(
        data,
        dtype,
        shape,
        strides,
        elem_offset,
        name,
        scope,
        data_alignment,
        offset_factor,
        buffer_type,
    )


@tvm._ffi.register_object("tir.DataProducer")
class DataProducer(Object):
    """Base class of data producers, registered to the tir.DataProducer node."""