1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file tvm/data_layout.h
 * \brief Layout expression to describe the data organization of a tensor,
 *  and BijectiveLayout to map between two data layouts.
24  */
25 #ifndef TVM_DATA_LAYOUT_H_
26 #define TVM_DATA_LAYOUT_H_
27 
28 #include <tvm/base.h>
29 #include <tvm/expr.h>
30 
31 #include <string>
32 #include <sstream>
33 #include <vector>
34 #include <utility>
35 #include <algorithm>
36 
37 #include "expr_operator.h"
38 
39 namespace tvm {
40 
41 class LayoutAxis {
42  public:
43   static const LayoutAxis& Get(const char name);
44 
45   // Get the singleton LayoutAxis using itvar->var->name_hint
46   static const LayoutAxis& Get(const IterVar& itvar);
47 
48   // Get the singleton LayoutAxis using name[0] (size of name must be 1).
49   static const LayoutAxis& make(const std::string& name);
50 
IsPrimal()51   inline bool IsPrimal() const { return name_ >= 'A' && name_ <= 'Z'; }
name()52   inline std::string name() const { return std::string(1, name_); }
53 
54   // if current axis is primal, switch the axis to its subordinate one,
55   // else switch to the primal.
ToDual()56   inline const LayoutAxis& ToDual() const {
57     if (name_ >= 'A' && name_ <= 'Z') {
58       return LayoutAxis::Get(name_ - 'A' + 'a');
59     } else {
60       return LayoutAxis::Get(name_ - 'a' + 'A');
61     }
62   }
63 
64   // return the primal axis. If it is already primal, return itself.
ToPrimal()65   const LayoutAxis& ToPrimal() const {
66     return IsPrimal() ? *this : ToDual();
67   }
68 
69   // return the subordinate axis. If it is already subordinate, return itself.
ToSubordinate()70   const LayoutAxis& ToSubordinate() const {
71     return IsPrimal() ? ToDual() : *this;
72   }
73 
74   inline bool operator==(const LayoutAxis& rhs) const {
75     return name_ == rhs.name_;
76   }
77 
78   friend std::ostream& operator<<(std::ostream& os, const LayoutAxis& l) {
79     os << l.name();
80     return os;
81   }
82 
83  private:
84   static const LayoutAxis UPPER_CASE[];
85   static const LayoutAxis LOWER_CASE[];
86   LayoutAxis(const LayoutAxis&);
87   LayoutAxis& operator=(const LayoutAxis&);
LayoutAxis(const char name)88   explicit LayoutAxis(const char name) : name_(name) {}
89 
90   const char name_;
91 };
92 
class Layout;
// Internal node container for Layout.
class LayoutNode : public Node {
 public:
  /*! \brief string representation of layout, "" for scalar. */
  std::string name;
  /*! \brief specify each axis of the layout,
   *   in which the variable name is the name of the axis.
   *   The IterVar's extent indicates the size of the axis,
   *   it is a variable for a primal axis, but a constant for a subordinate axis.
   *   Empty for scalar's layout.
   */
  Array<IterVar> axes;

  /*!
   * \brief Report each field to the attribute visitor \p v.
   *        Fields are visited in the order "name", then "axes".
   */
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("axes", &axes);
  }

  /*!
   * \brief Create a Layout from its string form.
   * \param layout the layout string, e.g. "NCHW16c".
   * \return the constructed Layout.
   * \note NOTE(review): presumably equivalent to Layout(layout); the
   *       definition is not in this header.
   */
  TVM_DLL static Layout make(const std::string& layout);

  static constexpr const char* _type_key = "Layout";
  TVM_DECLARE_NODE_TYPE_INFO(LayoutNode, Node);
};
117 
118 /*!
119  * \brief Layout is to describe how data is organized within an N-dimention tensor.
120  *  It is composed of upper cases, lower cases and numbers,
121  *  where upper case indicates a primal axis and
122  *  the corresponding lower case with factor size indicates the subordinate axis.
123  *  For example, NCHW16c can describe a 5-D tensor of
124  *  [batch_size, channel, height, width, channel_block].
125  *  Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel).
126  *  Layout for scalar is defined, while both its name and axes have size 0.
127  */
128 class Layout : public NodeRef {
129  public:
Layout(ObjectPtr<Object> n)130   explicit Layout(ObjectPtr<Object> n) : NodeRef(n) {}
131 
132   /*! \brief default constructor */
133   Layout() = default;
134 
135   explicit Layout(const Array<IterVar>& axes);
136 
137   /*! \brief construct from a string */
Layout(const char * name)138   Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*)
139 
140   /*!
141    * \brief construct from a string.
142    * \param name input in layout convention:
143    *        upper case indicates a dimension and
144    *        the corresponding lower case with factor size
145    *        indicates the split dimension.
146    *        return undefined layout if "__undef__" is passed.
147    */
148   Layout(const std::string& name); // NOLINT(*)
149 
150   /*!
151    * \brief access the internal node container
152    * \return the pointer to the internal node container
153    */
154   const LayoutNode* operator->() const {
155     return static_cast<const LayoutNode*>(get());
156   }
157 
158   /*!
159    * \brief access the internal node container
160    * \return the pointer to the internal node container
161    */
162   LayoutNode* operator->() {
163     return static_cast<LayoutNode*>(get_mutable());
164   }
165 
166   /*!
167    * \brief Return an undefined layout.
168    * \return a (global) undefined layout.
169    */
Undef()170   static const Layout& Undef() {
171     static Layout undef;
172     return undef;
173   }
174 
175   /*!
176    * \brief Returns a sub-layout which is the portion of the object
177    *        that starts at dimension \p pos and spans \p len dimensions
178    *        (or until the end of the layout, whichever comes first).
179    * \param pos The start position.
180    * \param len The length of the sub-layout. if 0, return layout of scalar
181    * \return A newly constructed Layout object.
182    */
183   Layout SubLayout(size_t pos, size_t len) const;
184 
185   /*!
186    * \brief Split \p axis by \p size and put the sub-axis to position \p target_pos.
187    * \param axis The source axis to be split. It must be a primal-axis;
188    * \param target_pos The target position of the newly split subordinate-axis.
189    * \param factor size of the sub-dimension.
190    * \return A newly constructed Layout object.
191    */
192   Layout Split(const LayoutAxis &axis, size_t target_pos, int32_t factor) const;
193 
194 
195   /*! \return number of dimensions */
ndim()196   inline size_t ndim() const {
197     if (!defined()) return 0;
198     return operator->()->axes.size();
199   }
200 
201   /*! \return number of super dimensions */
ndim_primal()202   inline size_t ndim_primal() const {
203     if (!defined()) return 0;
204     size_t ct = 0;
205     for (auto x : operator->()->axes) {
206       if (LayoutAxis::Get(x).IsPrimal()) {
207         ct++;
208       }
209     }
210     return ct;
211   }
212 
213   /*!
214    * \brief Returns a new layout where the dims have been expanded to match the primal dimensions.
215    * \param dst_layout The dst layout to which current layout has to be expanded.
216    * \return The expanded Layout.
217    */
ExpandPrimal(const Layout & dst_layout)218   inline Layout ExpandPrimal(const Layout& dst_layout) {
219     Layout new_src_layout;
220     // 1) Find the axis which are missing in the current layout. Make them the prefix.
221     std::string new_src_layout_str = "";
222     for (auto dst_axis : dst_layout->axes) {
223       if (LayoutAxis::Get(dst_axis).IsPrimal()) {
224         if (!this->Contains(LayoutAxis::Get(dst_axis))) {
225           new_src_layout_str += dst_axis->var->name_hint;
226         }
227       }
228     }
229     // 2) Now, add the primal axis of the current layout.
230     new_src_layout_str += this->name();
231     new_src_layout = Layout(new_src_layout_str);
232     return new_src_layout;
233   }
234 
235   /*!
236    * \brief return the index of the input axis.
237    *        If it is not found in the layout or the layout is undefined,
238    *        return -1.
239    * \param axis the input axis.
240    * \return the index or -1 if not found.
241    */
IndexOf(const LayoutAxis & axis)242   inline int32_t IndexOf(const LayoutAxis& axis) const {
243     if (!this->defined()) return -1;
244     const auto axes = operator->()->axes;
245     for (size_t i = 0; i < axes.size(); ++i) {
246       if (axes[i]->var->name_hint == axis.name()) return static_cast<int32_t>(i);
247     }
248     return -1;
249   }
250 
251   /*!
252    * \brief Get the factor size of the subordinate axis.
253    * \param axis the input primal-axis or subordinate-axis.
254    * \return the size of the subordinate-axis of \p axis (if \p axis is a primal-axis),
255    *         or the size of \p axis itself (if \p axis is a subordinate-axis).
256    *         Return -1 if \p axis is not in the layout the layout is undefined.
257    */
258   int32_t FactorOf(const LayoutAxis& axis) const;
259 
260   /*!
261    * \brief Whether the layout contains an axis.
262    * \param axis axis to be checked.
263    * \return Whether the layout contains the axis.
264    */
Contains(const LayoutAxis & axis)265   bool Contains(const LayoutAxis& axis) const {
266     if (!defined()) return false;
267     for (const IterVar var : operator->()->axes) {
268       if (var->var->name_hint == axis.name()) {
269         return true;
270       }
271     }
272     return false;
273   }
274 
275   const LayoutAxis& operator[](int32_t i) const {
276     CHECK(defined()) << "Try to access axis from an undefined layout.";
277     int32_t index = i < 0 ? static_cast<int32_t>(ndim() + i) : i;
278     CHECK(index >= 0 && static_cast<size_t>(index) < ndim()) << "Invalid index " << i;
279     const IterVar axis = operator->()->axes[index];
280     return LayoutAxis::Get(axis);
281   }
282 
283   /*! \return the string description of the layout */
name()284   inline std::string name() const {
285     if (!defined()) return "__undef__";
286     return operator->()->name;
287   }
288 
289   /*!
290    * \brief Whether the two layouts are equal.
291    * \param rhs Another layout.
292    * \return whether the two layouts are equal.
293    */
Equals(const Layout & rhs)294   inline bool Equals(const Layout &rhs) const {
295     return name() == rhs.name();
296   }
297 
298   /*!
299    * \brief allow output string of layout to ostream
300    * \param os the output stream
301    * \param l the layout
302    * \return the ostream
303    */
304   friend std::ostream& operator<<(std::ostream& os, const Layout& l) {
305     os << l.name();
306     return os;
307   }
308 
309   using ContainerType = LayoutNode;
310 };
311 
class BijectiveLayout;
// Internal node container for BijectiveLayout.
class BijectiveLayoutNode : public Node {
 public:
  /*! \brief Describes how source axes can be mapped to the destination axes,
   *   e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n
   */
  Array<Expr> forward_rule;
  /*! \brief Describes how destination axes can be mapped to the source axes */
  Array<Expr> backward_rule;

  /*! \brief The source layout */
  Layout src_layout;
  /*! \brief The destination layout */
  Layout dst_layout;

  /*!
   * \brief Report each field to the attribute visitor \p v.
   *        Fields are visited in the order "src_layout", "dst_layout",
   *        "forward_rule", "backward_rule".
   */
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("src_layout", &src_layout);
    v->Visit("dst_layout", &dst_layout);
    v->Visit("forward_rule", &forward_rule);
    v->Visit("backward_rule", &backward_rule);
  }

  static constexpr const char* _type_key = "BijectiveLayout";
  TVM_DECLARE_NODE_TYPE_INFO(BijectiveLayoutNode, Node);

  /*!
   * \brief Build the mapping between two layouts.
   * \param src_layout the source layout.
   * \param dst_layout the destination layout.
   * \return the BijectiveLayout (definition not in this header).
   */
  TVM_DLL static BijectiveLayout make(const Layout& src_layout,
                                      const Layout& dst_layout);
};
341 
342 /*! \brief Bijective function mapping for data layout transformation.
343  *   Given two Layout, BijectiveLayout build and store the mapping rules,
344  *   provides API to transform N-dimention tensor from the source indices (i0, i1, …, im)
345  *   to the destination indices (j0, j1, … jm).
346  */
347 class BijectiveLayout : public NodeRef {
348  public:
349   BijectiveLayout() = default;
BijectiveLayout(NodePtr<Node> n)350   explicit BijectiveLayout(NodePtr<Node> n) : NodeRef(n) {}
351 
352   // Given the source shape, infer the destination shape.
353   TVM_DLL Array<Expr> ForwardShape(const Array<Expr>& shape) const;
354   // Given the destination shape, recover the source shape.
355   TVM_DLL Array<Expr> BackwardShape(const Array<Expr>& dst_shape) const;
356   // Given the destination indices, infer the destination indices.
357   TVM_DLL Array<Expr> ForwardIndex(const Array<Expr>& index) const;
358   // Given the destination indices, recover the source indices.
359   TVM_DLL Array<Expr> BackwardIndex(const Array<Expr>& dst_index) const;
360 
361   /*!
362    * \brief access the internal node container
363    * \return the pointer to the internal node container
364    */
365   inline const BijectiveLayoutNode* operator->() const;
366 
367   /*! \brief specify container node */
368   using ContainerType = BijectiveLayoutNode;
369 };
370 
371 inline const BijectiveLayoutNode* BijectiveLayout::operator->() const {
372   return static_cast<const BijectiveLayoutNode*>(get());
373 }
374 
375 }  // namespace tvm
376 
377 #endif  // TVM_DATA_LAYOUT_H_
378