1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file tvm/data_layout.h 22 * \brief Layout expression to describe the data organization of a tensor. 23 * And BijectiveLayout to mapping two data layouts between each other. 24 */ 25 #ifndef TVM_DATA_LAYOUT_H_ 26 #define TVM_DATA_LAYOUT_H_ 27 28 #include <tvm/base.h> 29 #include <tvm/expr.h> 30 31 #include <string> 32 #include <sstream> 33 #include <vector> 34 #include <utility> 35 #include <algorithm> 36 37 #include "expr_operator.h" 38 39 namespace tvm { 40 41 class LayoutAxis { 42 public: 43 static const LayoutAxis& Get(const char name); 44 45 // Get the singleton LayoutAxis using itvar->var->name_hint 46 static const LayoutAxis& Get(const IterVar& itvar); 47 48 // Get the singleton LayoutAxis using name[0] (size of name must be 1). 49 static const LayoutAxis& make(const std::string& name); 50 IsPrimal()51 inline bool IsPrimal() const { return name_ >= 'A' && name_ <= 'Z'; } name()52 inline std::string name() const { return std::string(1, name_); } 53 54 // if current axis is primal, switch the axis to its subordinate one, 55 // else switch to the primal. ToDual()56 inline const LayoutAxis& ToDual() const { 57 if (name_ >= 'A' && name_ <= 'Z') { 58 return LayoutAxis::Get(name_ - 'A' + 'a'); 59 } else { 60 return LayoutAxis::Get(name_ - 'a' + 'A'); 61 } 62 } 63 64 // return the primal axis. If it is already primal, return itself. ToPrimal()65 const LayoutAxis& ToPrimal() const { 66 return IsPrimal() ? *this : ToDual(); 67 } 68 69 // return the subordinate axis. If it is already subordinate, return itself. ToSubordinate()70 const LayoutAxis& ToSubordinate() const { 71 return IsPrimal() ? ToDual() : *this; 72 } 73 74 inline bool operator==(const LayoutAxis& rhs) const { 75 return name_ == rhs.name_; 76 } 77 78 friend std::ostream& operator<<(std::ostream& os, const LayoutAxis& l) { 79 os << l.name(); 80 return os; 81 } 82 83 private: 84 static const LayoutAxis UPPER_CASE[]; 85 static const LayoutAxis LOWER_CASE[]; 86 LayoutAxis(const LayoutAxis&); 87 LayoutAxis& operator=(const LayoutAxis&); LayoutAxis(const char name)88 explicit LayoutAxis(const char name) : name_(name) {} 89 90 const char name_; 91 }; 92 93 class Layout; 94 // Internal node container Buffer 95 class LayoutNode : public Node { 96 public: 97 /*! \brief string representation of layout, "" for scalar. */ 98 std::string name; 99 /*! \brief specify each axis of the layout, 100 * in which the variable name is the name of the axis. 101 * The IterVar's extent indicates the size of the axis, 102 * it is a variable for a primal axis, but a constant for a subordinate axis. 103 * Empty for scalar's layout. 104 */ 105 Array<IterVar> axes; 106 VisitAttrs(AttrVisitor * v)107 void VisitAttrs(AttrVisitor* v) { 108 v->Visit("name", &name); 109 v->Visit("axes", &axes); 110 } 111 112 TVM_DLL static Layout make(const std::string& layout); 113 114 static constexpr const char* _type_key = "Layout"; 115 TVM_DECLARE_NODE_TYPE_INFO(LayoutNode, Node); 116 }; 117 118 /*! 119 * \brief Layout is to describe how data is organized within an N-dimention tensor. 120 * It is composed of upper cases, lower cases and numbers, 121 * where upper case indicates a primal axis and 122 * the corresponding lower case with factor size indicates the subordinate axis. 123 * For example, NCHW16c can describe a 5-D tensor of 124 * [batch_size, channel, height, width, channel_block]. 125 * Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel). 126 * Layout for scalar is defined, while both its name and axes have size 0. 127 */ 128 class Layout : public NodeRef { 129 public: Layout(ObjectPtr<Object> n)130 explicit Layout(ObjectPtr<Object> n) : NodeRef(n) {} 131 132 /*! \brief default constructor */ 133 Layout() = default; 134 135 explicit Layout(const Array<IterVar>& axes); 136 137 /*! \brief construct from a string */ Layout(const char * name)138 Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*) 139 140 /*! 141 * \brief construct from a string. 142 * \param name input in layout convention: 143 * upper case indicates a dimension and 144 * the corresponding lower case with factor size 145 * indicates the split dimension. 146 * return undefined layout if "__undef__" is passed. 147 */ 148 Layout(const std::string& name); // NOLINT(*) 149 150 /*! 151 * \brief access the internal node container 152 * \return the pointer to the internal node container 153 */ 154 const LayoutNode* operator->() const { 155 return static_cast<const LayoutNode*>(get()); 156 } 157 158 /*! 159 * \brief access the internal node container 160 * \return the pointer to the internal node container 161 */ 162 LayoutNode* operator->() { 163 return static_cast<LayoutNode*>(get_mutable()); 164 } 165 166 /*! 167 * \brief Return an undefined layout. 168 * \return a (global) undefined layout. 169 */ Undef()170 static const Layout& Undef() { 171 static Layout undef; 172 return undef; 173 } 174 175 /*! 176 * \brief Returns a sub-layout which is the portion of the object 177 * that starts at dimension \p pos and spans \p len dimensions 178 * (or until the end of the layout, whichever comes first). 179 * \param pos The start position. 180 * \param len The length of the sub-layout. if 0, return layout of scalar 181 * \return A newly constructed Layout object. 182 */ 183 Layout SubLayout(size_t pos, size_t len) const; 184 185 /*! 186 * \brief Split \p axis by \p size and put the sub-axis to position \p target_pos. 187 * \param axis The source axis to be split. It must be a primal-axis; 188 * \param target_pos The target position of the newly split subordinate-axis. 189 * \param factor size of the sub-dimension. 190 * \return A newly constructed Layout object. 191 */ 192 Layout Split(const LayoutAxis &axis, size_t target_pos, int32_t factor) const; 193 194 195 /*! \return number of dimensions */ ndim()196 inline size_t ndim() const { 197 if (!defined()) return 0; 198 return operator->()->axes.size(); 199 } 200 201 /*! \return number of super dimensions */ ndim_primal()202 inline size_t ndim_primal() const { 203 if (!defined()) return 0; 204 size_t ct = 0; 205 for (auto x : operator->()->axes) { 206 if (LayoutAxis::Get(x).IsPrimal()) { 207 ct++; 208 } 209 } 210 return ct; 211 } 212 213 /*! 214 * \brief Returns a new layout where the dims have been expanded to match the primal dimensions. 215 * \param dst_layout The dst layout to which current layout has to be expanded. 216 * \return The expanded Layout. 217 */ ExpandPrimal(const Layout & dst_layout)218 inline Layout ExpandPrimal(const Layout& dst_layout) { 219 Layout new_src_layout; 220 // 1) Find the axis which are missing in the current layout. Make them the prefix. 221 std::string new_src_layout_str = ""; 222 for (auto dst_axis : dst_layout->axes) { 223 if (LayoutAxis::Get(dst_axis).IsPrimal()) { 224 if (!this->Contains(LayoutAxis::Get(dst_axis))) { 225 new_src_layout_str += dst_axis->var->name_hint; 226 } 227 } 228 } 229 // 2) Now, add the primal axis of the current layout. 230 new_src_layout_str += this->name(); 231 new_src_layout = Layout(new_src_layout_str); 232 return new_src_layout; 233 } 234 235 /*! 236 * \brief return the index of the input axis. 237 * If it is not found in the layout or the layout is undefined, 238 * return -1. 239 * \param axis the input axis. 240 * \return the index or -1 if not found. 241 */ IndexOf(const LayoutAxis & axis)242 inline int32_t IndexOf(const LayoutAxis& axis) const { 243 if (!this->defined()) return -1; 244 const auto axes = operator->()->axes; 245 for (size_t i = 0; i < axes.size(); ++i) { 246 if (axes[i]->var->name_hint == axis.name()) return static_cast<int32_t>(i); 247 } 248 return -1; 249 } 250 251 /*! 252 * \brief Get the factor size of the subordinate axis. 253 * \param axis the input primal-axis or subordinate-axis. 254 * \return the size of the subordinate-axis of \p axis (if \p axis is a primal-axis), 255 * or the size of \p axis itself (if \p axis is a subordinate-axis). 256 * Return -1 if \p axis is not in the layout the layout is undefined. 257 */ 258 int32_t FactorOf(const LayoutAxis& axis) const; 259 260 /*! 261 * \brief Whether the layout contains an axis. 262 * \param axis axis to be checked. 263 * \return Whether the layout contains the axis. 264 */ Contains(const LayoutAxis & axis)265 bool Contains(const LayoutAxis& axis) const { 266 if (!defined()) return false; 267 for (const IterVar var : operator->()->axes) { 268 if (var->var->name_hint == axis.name()) { 269 return true; 270 } 271 } 272 return false; 273 } 274 275 const LayoutAxis& operator[](int32_t i) const { 276 CHECK(defined()) << "Try to access axis from an undefined layout."; 277 int32_t index = i < 0 ? static_cast<int32_t>(ndim() + i) : i; 278 CHECK(index >= 0 && static_cast<size_t>(index) < ndim()) << "Invalid index " << i; 279 const IterVar axis = operator->()->axes[index]; 280 return LayoutAxis::Get(axis); 281 } 282 283 /*! \return the string description of the layout */ name()284 inline std::string name() const { 285 if (!defined()) return "__undef__"; 286 return operator->()->name; 287 } 288 289 /*! 290 * \brief Whether the two layouts are equal. 291 * \param rhs Another layout. 292 * \return whether the two layouts are equal. 293 */ Equals(const Layout & rhs)294 inline bool Equals(const Layout &rhs) const { 295 return name() == rhs.name(); 296 } 297 298 /*! 299 * \brief allow output string of layout to ostream 300 * \param os the output stream 301 * \param l the layout 302 * \return the ostream 303 */ 304 friend std::ostream& operator<<(std::ostream& os, const Layout& l) { 305 os << l.name(); 306 return os; 307 } 308 309 using ContainerType = LayoutNode; 310 }; 311 312 class BijectiveLayout; 313 // Internal node container BijectiveLayout 314 class BijectiveLayoutNode : public Node { 315 public: 316 /*! \brief Describes how source axes can be mapped to the destination axes, 317 * e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n 318 */ 319 Array<Expr> forward_rule; 320 /*! \brief Describes how destination axes can be mapped to the source axes */ 321 Array<Expr> backward_rule; 322 323 /*! \brief The source layout */ 324 Layout src_layout; 325 /*! \brief The destination layout */ 326 Layout dst_layout; 327 VisitAttrs(AttrVisitor * v)328 void VisitAttrs(AttrVisitor* v) { 329 v->Visit("src_layout", &src_layout); 330 v->Visit("dst_layout", &dst_layout); 331 v->Visit("forward_rule", &forward_rule); 332 v->Visit("backward_rule", &backward_rule); 333 } 334 335 static constexpr const char* _type_key = "BijectiveLayout"; 336 TVM_DECLARE_NODE_TYPE_INFO(BijectiveLayoutNode, Node); 337 338 TVM_DLL static BijectiveLayout make(const Layout& src_layout, 339 const Layout& dst_layout); 340 }; 341 342 /*! \brief Bijective function mapping for data layout transformation. 343 * Given two Layout, BijectiveLayout build and store the mapping rules, 344 * provides API to transform N-dimention tensor from the source indices (i0, i1, …, im) 345 * to the destination indices (j0, j1, … jm). 346 */ 347 class BijectiveLayout : public NodeRef { 348 public: 349 BijectiveLayout() = default; BijectiveLayout(NodePtr<Node> n)350 explicit BijectiveLayout(NodePtr<Node> n) : NodeRef(n) {} 351 352 // Given the source shape, infer the destination shape. 353 TVM_DLL Array<Expr> ForwardShape(const Array<Expr>& shape) const; 354 // Given the destination shape, recover the source shape. 355 TVM_DLL Array<Expr> BackwardShape(const Array<Expr>& dst_shape) const; 356 // Given the destination indices, infer the destination indices. 357 TVM_DLL Array<Expr> ForwardIndex(const Array<Expr>& index) const; 358 // Given the destination indices, recover the source indices. 359 TVM_DLL Array<Expr> BackwardIndex(const Array<Expr>& dst_index) const; 360 361 /*! 362 * \brief access the internal node container 363 * \return the pointer to the internal node container 364 */ 365 inline const BijectiveLayoutNode* operator->() const; 366 367 /*! \brief specify container node */ 368 using ContainerType = BijectiveLayoutNode; 369 }; 370 371 inline const BijectiveLayoutNode* BijectiveLayout::operator->() const { 372 return static_cast<const BijectiveLayoutNode*>(get()); 373 } 374 375 } // namespace tvm 376 377 #endif // TVM_DATA_LAYOUT_H_ 378