1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file tvm/buffer.h 22 * \brief Symbolic n-dimensional array, to represent a memory buffer. 23 */ 24 #ifndef TVM_BUFFER_H_ 25 #define TVM_BUFFER_H_ 26 27 #include <string> 28 29 #include "base.h" 30 #include "expr.h" 31 #include "expr_operator.h" 32 #include "tvm/node/container.h" 33 34 namespace tvm { 35 36 // Internal node container Buffer 37 class BufferNode; 38 39 /*! \brief buffer type */ 40 enum BufferType : int { 41 kDefault = 1, 42 // Maps buffer[i][j][k] -> buffer[i][0][k] if dimension i's shape equals 1. 43 kAutoBroadcast = 2, 44 }; 45 46 /*! 47 * \brief Buffer is a symbolic n-darray structure. 48 * It is a composition of primitive symbolic types, 49 * used to specify the memory layout of the Tensor used in program input. 50 */ 51 class Buffer : public NodeRef { 52 public: Buffer()53 Buffer() {} Buffer(ObjectPtr<Object> n)54 explicit Buffer(ObjectPtr<Object> n) : NodeRef(n) {} 55 /*! 56 * \brief Return a new buffer that is equivalent with current one 57 * but always add stride field. 58 * \return The strided version of the buffer. 59 */ 60 TVM_DLL Buffer MakeStrideView() const; 61 /*! 62 * \brief Make a new symbolic buffer representing a slice of the buffer. 63 * \param begins The beginning position of each dimension. 64 * \param extents The extent of each dimension. 65 * \note This function will make target buffer as compact as possible. 66 * If stride is not needed in the slice, it won't be presented 67 * \return the result buffer. 68 */ 69 TVM_DLL Buffer MakeSlice(Array<Expr> begins, Array<Expr> extents) const; 70 /*! 71 * \brief Get access ptr to the entire buffer. 72 * \param access_mask The access mask 73 * \param ptr_type The type of the pointer. 74 * \param content_lanes The number of lanes for the (data) type. 75 * \param offset The offset of ptr. 76 */ 77 TVM_DLL Expr access_ptr(int access_mask, Type ptr_type = Handle(), 78 int content_lanes = 1, Expr offset = make_const(Int(32), 0)) const; 79 /*! 80 * \brief Create an Expr that does a vector load at begin index. 81 * \param begin The beginning index 82 * \param dtype The data type to be loaded. 83 */ 84 TVM_DLL Expr vload(Array<Expr> begin, Type dtype) const; 85 /*! 86 * \brief Create a Stmt that does a vector store at begin index. 87 * \param begin The beginning index 88 * \param value The value to be stored. 89 */ 90 TVM_DLL Stmt vstore(Array<Expr> begin, Expr value) const; 91 /*! 92 * \brief access the internal node container 93 * \return the pointer to the internal node container 94 */ 95 inline const BufferNode* operator->() const; 96 97 /*! \brief specify container node */ 98 using ContainerType = BufferNode; 99 }; 100 101 /*! \brief Node to represent a buffer */ 102 class BufferNode : public Node { 103 public: 104 // Data fields. 105 /*! 106 * \brief The pointer to the head of the data 107 * \sa data_alignment The alignment of data in bytes. 108 */ 109 Var data; 110 /*! \brief data type in the content of the tensor */ 111 Type dtype; 112 /*! \brief The shape of the buffer */ 113 Array<Expr> shape; 114 /*! 115 * \brief The strides of each dimension 116 * This can be an empty array, indicating array is contiguous 117 */ 118 Array<Expr> strides; 119 /*! \brief The offset in terms of number of dtype elements (including lanes) */ 120 Expr elem_offset; 121 // Meta data 122 /*! \brief optional name of the buffer */ 123 std::string name; 124 /*! \brief storage scope of the buffer, if other than global */ 125 std::string scope; 126 /*! \brief Alignment requirement of data pointer in bytes. */ 127 int data_alignment; 128 /*! 129 * \brief Factor of elem_offset field, 130 * elem_offset is guaranteed to be multiple of offset_factor. 131 */ 132 int offset_factor; 133 /*! \brief buffer type */ 134 BufferType buffer_type; 135 /*! \brief constructor */ BufferNode()136 BufferNode() {} 137 VisitAttrs(AttrVisitor * v)138 void VisitAttrs(AttrVisitor* v) { 139 v->Visit("data", &data); 140 v->Visit("dtype", &dtype); 141 v->Visit("shape", &shape); 142 v->Visit("strides", &strides); 143 v->Visit("elem_offset", &elem_offset); 144 v->Visit("name", &name); 145 v->Visit("scope", &scope); 146 v->Visit("data_alignment", &data_alignment); 147 v->Visit("offset_factor", &offset_factor); 148 v->Visit("buffer_type", &buffer_type); 149 } 150 151 /*! \return preferred index type for this buffer node */ DefaultIndexType()152 Type DefaultIndexType() const { 153 return shape.size() != 0 ? shape[0].type() : Int(32); 154 } 155 156 // User can specify data_alignment and offset_factor to be 0 157 // A default value will be picked. 158 TVM_DLL static Buffer make(Var ptr, 159 Type dtype, 160 Array<Expr> shape, 161 Array<Expr> strides, 162 Expr elem_offset, 163 std::string name, 164 std::string scope, 165 int data_alignment, 166 int offset_factor, 167 BufferType buffer_type); 168 169 static constexpr const char* _type_key = "Buffer"; 170 TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node); 171 }; 172 173 inline const BufferNode* Buffer::operator->() const { 174 return static_cast<const BufferNode*>(get()); 175 } 176 177 /*! 178 * \brief Construct a new buffer given shape, and dtype. 179 * \param shape The shape of the buffer, 180 * \param dtype The content data type. 181 * \param name The name of the buffer 182 * \return The created buffer. 183 * \sa BufferNode::make for complete constructor. 184 */ 185 TVM_DLL Buffer decl_buffer(Array<Expr> shape, 186 Type dtype = Float(32), 187 std::string name = "buffer"); 188 } // namespace tvm 189 #endif // TVM_BUFFER_H_ 190