1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file tvm/buffer.h
22  * \brief Symbolic n-dimensional array, to represent a memory buffer.
23  */
24 #ifndef TVM_BUFFER_H_
25 #define TVM_BUFFER_H_
26 
#include <string>
#include <utility>

#include "base.h"
#include "expr.h"
#include "expr_operator.h"
#include "tvm/node/container.h"
33 
34 namespace tvm {
35 
// Forward declaration of the internal node container of Buffer.
class BufferNode;

/*! \brief buffer type */
enum BufferType : int {
  /*! \brief Default buffer type. */
  kDefault = 1,
  /*!
   * \brief Auto-broadcast buffer.
   *  Maps buffer[i][j][k] -> buffer[i][0][k] if dimension j's shape equals 1,
   *  i.e. the index into a size-1 dimension is replaced by 0.
   */
  kAutoBroadcast = 2,
};
45 
46 /*!
47  * \brief Buffer is a symbolic n-darray structure.
48  *  It is a composition of primitive symbolic types,
49  *  used to specify the memory layout of the Tensor used in program input.
50  */
51 class Buffer : public NodeRef {
52  public:
Buffer()53   Buffer() {}
Buffer(ObjectPtr<Object> n)54   explicit Buffer(ObjectPtr<Object> n) : NodeRef(n) {}
55   /*!
56    * \brief Return a new buffer that is equivalent with current one
57    *  but always add stride field.
58    * \return The strided version of the buffer.
59    */
60   TVM_DLL Buffer MakeStrideView() const;
61   /*!
62    * \brief Make a new symbolic buffer representing a slice of the buffer.
63    * \param begins The beginning position of each dimension.
64    * \param extents The extent of each dimension.
65    * \note This function will make target buffer as compact as possible.
66    *  If stride is not needed in the slice, it won't be presented
67    * \return the result buffer.
68    */
69   TVM_DLL Buffer MakeSlice(Array<Expr> begins, Array<Expr> extents) const;
70   /*!
71    * \brief Get access ptr to the entire buffer.
72    * \param access_mask The access mask
73    * \param ptr_type The type of the pointer.
74    * \param content_lanes The number of lanes for the (data) type.
75    * \param offset The offset of ptr.
76    */
77   TVM_DLL Expr access_ptr(int access_mask, Type ptr_type = Handle(),
78                           int content_lanes = 1, Expr offset = make_const(Int(32), 0)) const;
79   /*!
80    * \brief Create an Expr that does a vector load at begin index.
81    * \param begin The beginning index
82    * \param dtype The data type to be loaded.
83    */
84   TVM_DLL Expr vload(Array<Expr> begin, Type dtype) const;
85   /*!
86    * \brief Create a Stmt that does a vector store at begin index.
87    * \param begin The beginning index
88    * \param value The value to be stored.
89    */
90   TVM_DLL Stmt vstore(Array<Expr> begin, Expr value) const;
91   /*!
92    * \brief access the internal node container
93    * \return the pointer to the internal node container
94    */
95   inline const BufferNode* operator->() const;
96 
97   /*! \brief specify container node */
98   using ContainerType = BufferNode;
99 };
100 
101 /*! \brief Node to represent a buffer */
102 class BufferNode : public Node {
103  public:
104   // Data fields.
105   /*!
106    * \brief The pointer to the head of the data
107    * \sa data_alignment The alignment of data in bytes.
108    */
109   Var data;
110   /*! \brief data type in the content of the tensor */
111   Type dtype;
112   /*! \brief The shape of the buffer */
113   Array<Expr> shape;
114   /*!
115    * \brief The strides of each dimension
116    *  This can be an empty array, indicating array is contiguous
117    */
118   Array<Expr> strides;
119   /*! \brief The offset in terms of number of dtype elements (including lanes) */
120   Expr elem_offset;
121   // Meta data
122   /*! \brief optional name of the buffer */
123   std::string name;
124   /*! \brief storage scope of the buffer, if other than global */
125   std::string scope;
126   /*! \brief Alignment requirement of data pointer in bytes. */
127   int data_alignment;
128   /*!
129    * \brief Factor of elem_offset field,
130    *  elem_offset is guaranteed to be multiple of offset_factor.
131    */
132   int offset_factor;
133   /*! \brief buffer type */
134   BufferType buffer_type;
135   /*! \brief constructor */
BufferNode()136   BufferNode() {}
137 
VisitAttrs(AttrVisitor * v)138   void VisitAttrs(AttrVisitor* v) {
139     v->Visit("data", &data);
140     v->Visit("dtype", &dtype);
141     v->Visit("shape", &shape);
142     v->Visit("strides", &strides);
143     v->Visit("elem_offset", &elem_offset);
144     v->Visit("name", &name);
145     v->Visit("scope", &scope);
146     v->Visit("data_alignment", &data_alignment);
147     v->Visit("offset_factor", &offset_factor);
148     v->Visit("buffer_type", &buffer_type);
149   }
150 
151   /*! \return preferred index type for this buffer node */
DefaultIndexType()152   Type DefaultIndexType() const {
153     return shape.size() != 0 ? shape[0].type() : Int(32);
154   }
155 
156   // User can specify data_alignment and offset_factor to be 0
157   // A default value will be picked.
158   TVM_DLL static Buffer make(Var ptr,
159                              Type dtype,
160                              Array<Expr> shape,
161                              Array<Expr> strides,
162                              Expr elem_offset,
163                              std::string name,
164                              std::string scope,
165                              int data_alignment,
166                              int offset_factor,
167                              BufferType buffer_type);
168 
169   static constexpr const char* _type_key = "Buffer";
170   TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node);
171 };
172 
173 inline const BufferNode* Buffer::operator->() const {
174   return static_cast<const BufferNode*>(get());
175 }
176 
177 /*!
178  * \brief Construct a new buffer given shape, and dtype.
179  * \param shape The shape of the buffer,
180  * \param dtype The content data type.
181  * \param name The name of the buffer
182  * \return The created buffer.
183  * \sa BufferNode::make for complete constructor.
184  */
185 TVM_DLL Buffer decl_buffer(Array<Expr> shape,
186                            Type dtype = Float(32),
187                            std::string name = "buffer");
188 }  // namespace tvm
189 #endif  // TVM_BUFFER_H_
190