1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file tvm/tir/expr.h
22  * \brief TIR expressions.
23  */
24 // Acknowledgement: Many low-level IR nodes originate from Halide.
25 #ifndef TVM_TIR_EXPR_H_
26 #define TVM_TIR_EXPR_H_
27 
28 #include <tvm/ir/expr.h>
29 #include <tvm/node/container.h>
30 #include <tvm/node/functor.h>
31 #include <tvm/node/node.h>
32 #include <tvm/runtime/c_runtime_api.h>
33 #include <tvm/runtime/data_type.h>
34 #include <tvm/tir/buffer.h>
35 #include <tvm/tir/var.h>
36 
37 #include <algorithm>
38 #include <iostream>
39 #include <limits>
40 #include <string>
41 #include <unordered_map>
42 #include <utility>
43 
44 namespace tvm {
45 namespace tir {
46 
// Re-export the immediate-constant node types declared in the parent tvm
// namespace so TIR code can refer to them as tir::IntImmNode / tir::FloatImmNode.
using IntImmNode = tvm::IntImmNode;
using FloatImmNode = tvm::FloatImmNode;
49 
50 /*! \brief String constants, only used in asserts. */
51 class StringImmNode : public PrimExprNode {
52  public:
53   /*! \brief The constant value content. */
54   String value;
55 
VisitAttrs(AttrVisitor * v)56   void VisitAttrs(AttrVisitor* v) {
57     v->Visit("dtype", &dtype);
58     v->Visit("value", &value);
59   }
60 
SEqualReduce(const StringImmNode * other,SEqualReducer equal)61   bool SEqualReduce(const StringImmNode* other, SEqualReducer equal) const {
62     return equal(value, other->value);
63   }
64 
SHashReduce(SHashReducer hash_reduce)65   void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
66 
67   static constexpr const char* _type_key = "tir.StringImm";
68   TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, PrimExprNode);
69 };
70 
71 /*!
72  * \brief Managed reference to StringImmNode.
73  * \sa StringImmNode
74  */
75 class StringImm : public PrimExpr {
76  public:
77   TVM_DLL StringImm(String value);
78   TVM_DEFINE_OBJECT_REF_METHODS(StringImm, PrimExpr, StringImmNode);
79 };
80 
81 /*!
82  * \brief Cast value from one data type to another.
83  * \note The lanes of value should keep fixed.
84  */
85 class CastNode : public PrimExprNode {
86  public:
87   /*! \brief Original data type. */
88   PrimExpr value;
89 
VisitAttrs(AttrVisitor * v)90   void VisitAttrs(AttrVisitor* v) {
91     v->Visit("dtype", &dtype);
92     v->Visit("value", &value);
93   }
94 
SEqualReduce(const CastNode * other,SEqualReducer equal)95   bool SEqualReduce(const CastNode* other, SEqualReducer equal) const {
96     return equal(dtype, other->dtype) && equal(value, other->value);
97   }
98 
SHashReduce(SHashReducer hash_reduce)99   void SHashReduce(SHashReducer hash_reduce) const {
100     hash_reduce(dtype);
101     hash_reduce(value);
102   }
103 
104   static constexpr const char* _type_key = "tir.Cast";
105   TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, PrimExprNode);
106 };
107 
108 /*!
109  * \brief Managed reference to CastNode
110  * \sa CastNode
111  */
112 class Cast : public PrimExpr {
113  public:
114   TVM_DLL Cast(DataType dtype, PrimExpr value);
115   TVM_DEFINE_OBJECT_REF_METHODS(Cast, PrimExpr, CastNode);
116 };
117 
118 /*!
119  * \brief Base template to implement binary ops.
120  * \tparam T The type of the child class.
121  */
122 template <typename T>
123 class BinaryOpNode : public PrimExprNode {
124  public:
125   /*! \brief The left operand. */
126   PrimExpr a;
127   /*! \brief The right operand. */
128   PrimExpr b;
129 
VisitAttrs(AttrVisitor * v)130   void VisitAttrs(AttrVisitor* v) {
131     v->Visit("dtype", &(this->dtype));
132     v->Visit("a", &a);
133     v->Visit("b", &b);
134   }
135 
SEqualReduce(const T * other,SEqualReducer equal)136   bool SEqualReduce(const T* other, SEqualReducer equal) const {
137     return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b);
138   }
139 
SHashReduce(SHashReducer hash_reduce)140   void SHashReduce(SHashReducer hash_reduce) const {
141     hash_reduce(dtype);
142     hash_reduce(a);
143     hash_reduce(b);
144   }
145 
146   TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode);
147 };
148 
149 /*! \brief a + b */
150 class AddNode : public BinaryOpNode<AddNode> {
151  public:
152   static constexpr const char* _type_key = "tir.Add";
153 };
154 
155 /*!
156  * \brief Managed reference to AddNode
157  * \sa AddNode
158  */
159 class Add : public PrimExpr {
160  public:
161   TVM_DLL Add(PrimExpr a, PrimExpr b);
162   TVM_DEFINE_OBJECT_REF_METHODS(Add, PrimExpr, AddNode);
163 };
164 
165 /*! \brief a - b */
166 class SubNode : public BinaryOpNode<SubNode> {
167  public:
168   static constexpr const char* _type_key = "tir.Sub";
169 };
170 
171 /*!
172  * \brief Managed reference to SubNode
173  * \sa SubNode
174  */
175 class Sub : public PrimExpr {
176  public:
177   TVM_DLL Sub(PrimExpr a, PrimExpr b);
178   TVM_DEFINE_OBJECT_REF_METHODS(Sub, PrimExpr, SubNode);
179 };
180 
181 /*! \brief a * b */
182 class MulNode : public BinaryOpNode<MulNode> {
183  public:
184   static constexpr const char* _type_key = "tir.Mul";
185 };
186 
187 /*!
188  * \brief Managed reference to MulNode
189  * \sa MulNode
190  */
191 class Mul : public PrimExpr {
192  public:
193   TVM_DLL Mul(PrimExpr a, PrimExpr b);
194   TVM_DEFINE_OBJECT_REF_METHODS(Mul, PrimExpr, MulNode);
195 };
196 
197 /*!
198  * \brief a / b in the C semnatics.
199  * \note For integer division, C standard uses trunc div.
200  */
201 class DivNode : public BinaryOpNode<DivNode> {
202  public:
203   static constexpr const char* _type_key = "tir.Div";
204 };
205 
206 /*!
207  * \brief Managed reference to DivNode
208  * \sa DivNode
209  */
210 class Div : public PrimExpr {
211  public:
212   TVM_DLL Div(PrimExpr a, PrimExpr b);
213   TVM_DEFINE_OBJECT_REF_METHODS(Div, PrimExpr, DivNode);
214 };
215 
216 /*!
217  * \brief a % b in the C semnatics.
218  * \note For integer division, C standard uses trunc div.
219  */
220 class ModNode : public BinaryOpNode<ModNode> {
221  public:
222   static constexpr const char* _type_key = "tir.Mod";
223 };
224 
225 /*!
226  * \brief Managed reference to ModNode
227  * \sa ModNode
228  */
229 class Mod : public PrimExpr {
230  public:
231   TVM_DLL Mod(PrimExpr a, PrimExpr b);
232   TVM_DEFINE_OBJECT_REF_METHODS(Mod, PrimExpr, ModNode);
233 };
234 
235 /*! \brief Floor division, floor(a/b) */
236 class FloorDivNode : public BinaryOpNode<FloorDivNode> {
237  public:
238   static constexpr const char* _type_key = "tir.FloorDiv";
239 };
240 
241 /*!
242  * \brief Managed reference to FloorDivNode
243  * \sa FloorDivNode
244  */
245 class FloorDiv : public PrimExpr {
246  public:
247   TVM_DLL FloorDiv(PrimExpr a, PrimExpr b);
248   TVM_DEFINE_OBJECT_REF_METHODS(FloorDiv, PrimExpr, FloorDivNode);
249 };
250 
251 /*! \brief The remainder of the floordiv */
252 class FloorModNode : public BinaryOpNode<FloorModNode> {
253  public:
254   static constexpr const char* _type_key = "tir.FloorMod";
255 };
256 
257 /*!
258  * \brief Managed reference to FloorModNode
259  * \sa FloorModNode
260  */
261 class FloorMod : public PrimExpr {
262  public:
263   TVM_DLL FloorMod(PrimExpr a, PrimExpr b);
264   TVM_DEFINE_OBJECT_REF_METHODS(FloorMod, PrimExpr, FloorModNode);
265 };
266 
267 /*! \brief min(a, b) */
268 class MinNode : public BinaryOpNode<MinNode> {
269  public:
270   static constexpr const char* _type_key = "tir.Min";
271 };
272 
273 /*!
274  * \brief Managed reference to MinNode
275  * \sa MinNode
276  */
277 class Min : public PrimExpr {
278  public:
279   TVM_DLL Min(PrimExpr a, PrimExpr b);
280   TVM_DEFINE_OBJECT_REF_METHODS(Min, PrimExpr, MinNode);
281 };
282 
283 /*! \brief max(a, b) */
284 class MaxNode : public BinaryOpNode<MaxNode> {
285  public:
286   static constexpr const char* _type_key = "tir.Max";
287 };
288 
289 /*!
290  * \brief Managed reference to MaxNode
291  * \sa MaxNode
292  */
293 class Max : public PrimExpr {
294  public:
295   TVM_DLL Max(PrimExpr a, PrimExpr b);
296   TVM_DEFINE_OBJECT_REF_METHODS(Max, PrimExpr, MaxNode);
297 };
298 
299 /*!
300  * \brief Base template to implement comparison ops.
301  * \tparam T The type of the child class.
302  */
303 template <typename T>
304 class CmpOpNode : public PrimExprNode {
305  public:
306   /*! \brief The left operand. */
307   PrimExpr a;
308   /*! \brief The right operand. */
309   PrimExpr b;
310 
VisitAttrs(AttrVisitor * v)311   void VisitAttrs(AttrVisitor* v) {
312     v->Visit("dtype", &(this->dtype));
313     v->Visit("a", &a);
314     v->Visit("b", &b);
315   }
316 
SEqualReduce(const T * other,SEqualReducer equal)317   bool SEqualReduce(const T* other, SEqualReducer equal) const {
318     return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b);
319   }
320 
SHashReduce(SHashReducer hash_reduce)321   void SHashReduce(SHashReducer hash_reduce) const {
322     hash_reduce(dtype);
323     hash_reduce(a);
324     hash_reduce(b);
325   }
326 
327   TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode);
328 };
329 
330 /*! \brief a == b */
331 class EQNode : public CmpOpNode<EQNode> {
332  public:
333   static constexpr const char* _type_key = "tir.EQ";
334 };
335 
336 /*!
337  * \brief Managed reference to EQNode
338  * \sa EQNode
339  */
340 class EQ : public PrimExpr {
341  public:
342   TVM_DLL EQ(PrimExpr a, PrimExpr b);
343   TVM_DEFINE_OBJECT_REF_METHODS(EQ, PrimExpr, EQNode);
344 };
345 
346 /*! \brief a != b */
347 class NENode : public CmpOpNode<NENode> {
348  public:
349   static constexpr const char* _type_key = "tir.NE";
350 };
351 
352 /*!
353  * \brief Managed reference to NENode
354  * \sa NENode
355  */
356 class NE : public PrimExpr {
357  public:
358   TVM_DLL NE(PrimExpr a, PrimExpr b);
359   TVM_DEFINE_OBJECT_REF_METHODS(NE, PrimExpr, NENode);
360 };
361 
362 /*! \brief a < b */
363 class LTNode : public CmpOpNode<LTNode> {
364  public:
365   static constexpr const char* _type_key = "tir.LT";
366 };
367 
368 /*!
369  * \brief Managed reference to LTNode
370  * \sa LTNode
371  */
372 class LT : public PrimExpr {
373  public:
374   TVM_DLL LT(PrimExpr a, PrimExpr b);
375   TVM_DEFINE_OBJECT_REF_METHODS(LT, PrimExpr, LTNode);
376 };
377 
378 /*! \brief a <= b */
379 struct LENode : public CmpOpNode<LENode> {
380  public:
381   static constexpr const char* _type_key = "tir.LE";
382 };
383 
384 /*!
385  * \brief Managed reference to LENode
386  * \sa LENode
387  */
388 class LE : public PrimExpr {
389  public:
390   TVM_DLL LE(PrimExpr a, PrimExpr b);
391   TVM_DEFINE_OBJECT_REF_METHODS(LE, PrimExpr, LENode);
392 };
393 
394 /*! \brief a > b */
395 class GTNode : public CmpOpNode<GTNode> {
396  public:
397   static constexpr const char* _type_key = "tir.GT";
398 };
399 
400 /*!
401  * \brief Managed reference to GTNode
402  * \sa GTNode
403  */
404 class GT : public PrimExpr {
405  public:
406   TVM_DLL GT(PrimExpr a, PrimExpr b);
407   TVM_DEFINE_OBJECT_REF_METHODS(GT, PrimExpr, GTNode);
408 };
409 
410 /*! \brief a >= b */
411 class GENode : public CmpOpNode<GENode> {
412  public:
413   static constexpr const char* _type_key = "tir.GE";
414 };
415 
416 /*!
417  * \brief Managed reference to GENode
418  * \sa GENode
419  */
420 class GE : public PrimExpr {
421  public:
422   TVM_DLL GE(PrimExpr a, PrimExpr b);
423   TVM_DEFINE_OBJECT_REF_METHODS(GE, PrimExpr, GENode);
424 };
425 
426 /*! \brief a && b */
427 class AndNode : public PrimExprNode {
428  public:
429   /*! \brief The left operand. */
430   PrimExpr a;
431   /*! \brief The right operand. */
432   PrimExpr b;
433 
VisitAttrs(AttrVisitor * v)434   void VisitAttrs(AttrVisitor* v) {
435     v->Visit("dtype", &(this->dtype));
436     v->Visit("a", &a);
437     v->Visit("b", &b);
438   }
439 
SEqualReduce(const AndNode * other,SEqualReducer equal)440   bool SEqualReduce(const AndNode* other, SEqualReducer equal) const {
441     return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b);
442   }
443 
SHashReduce(SHashReducer hash_reduce)444   void SHashReduce(SHashReducer hash_reduce) const {
445     hash_reduce(dtype);
446     hash_reduce(a);
447     hash_reduce(b);
448   }
449 
450   static constexpr const char* _type_key = "tir.And";
451   TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, PrimExprNode);
452 };
453 
454 /*!
455  * \brief Managed reference to AndNode
456  * \sa AndNode
457  */
458 class And : public PrimExpr {
459  public:
460   TVM_DLL And(PrimExpr a, PrimExpr b);
461   TVM_DEFINE_OBJECT_REF_METHODS(And, PrimExpr, AndNode);
462 };
463 
464 /*! \brief a || b */
465 class OrNode : public PrimExprNode {
466  public:
467   /*! \brief The left operand. */
468   PrimExpr a;
469   /*! \brief The right operand. */
470   PrimExpr b;
471 
VisitAttrs(AttrVisitor * v)472   void VisitAttrs(AttrVisitor* v) {
473     v->Visit("dtype", &dtype);
474     v->Visit("a", &a);
475     v->Visit("b", &b);
476   }
477 
SEqualReduce(const OrNode * other,SEqualReducer equal)478   bool SEqualReduce(const OrNode* other, SEqualReducer equal) const {
479     return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b);
480   }
481 
SHashReduce(SHashReducer hash_reduce)482   void SHashReduce(SHashReducer hash_reduce) const {
483     hash_reduce(dtype);
484     hash_reduce(a);
485     hash_reduce(b);
486   }
487 
488   static constexpr const char* _type_key = "tir.Or";
489   TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, PrimExprNode);
490 };
491 
492 /*!
493  * \brief Managed reference to OrNode
494  * \sa OrNode
495  */
496 class Or : public PrimExpr {
497  public:
498   TVM_DLL Or(PrimExpr a, PrimExpr b);
499   TVM_DEFINE_OBJECT_REF_METHODS(Or, PrimExpr, OrNode);
500 };
501 
502 /*! \brief !a */
503 class NotNode : public PrimExprNode {
504  public:
505   /*! \brief The input operand. */
506   PrimExpr a;
507 
VisitAttrs(AttrVisitor * v)508   void VisitAttrs(AttrVisitor* v) {
509     v->Visit("dtype", &dtype);
510     v->Visit("a", &a);
511   }
512 
SEqualReduce(const NotNode * other,SEqualReducer equal)513   bool SEqualReduce(const NotNode* other, SEqualReducer equal) const {
514     return equal(dtype, other->dtype) && equal(a, other->a);
515   }
516 
SHashReduce(SHashReducer hash_reduce)517   void SHashReduce(SHashReducer hash_reduce) const {
518     hash_reduce(dtype);
519     hash_reduce(a);
520   }
521 
522   static constexpr const char* _type_key = "tir.Not";
523   TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, PrimExprNode);
524 };
525 
526 /*!
527  * \brief Managed reference to NotNode
528  * \sa NotNode
529  */
530 class Not : public PrimExpr {
531  public:
532   TVM_DLL Not(PrimExpr a);
533   TVM_DEFINE_OBJECT_REF_METHODS(Not, PrimExpr, NotNode);
534 };
535 
536 /*!
537  * \brief return true_value if condition is true, otherwise return false_value.
538  * \note Both true_value and false_value could be evaluated
539  *       regardless of the condition value.
540  *       Do not use it to guard against out of bound access,
541  *       please use if_then_else instead.
542  */
543 class SelectNode : public PrimExprNode {
544  public:
545   /*! \brief The condition */
546   PrimExpr condition;
547   /*! \brief value to be returned when condition is true. */
548   PrimExpr true_value;
549   /*! \brief value to be returned when condition is false. */
550   PrimExpr false_value;
551 
VisitAttrs(AttrVisitor * v)552   void VisitAttrs(AttrVisitor* v) {
553     v->Visit("dtype", &dtype);
554     v->Visit("condition", &condition);
555     v->Visit("true_value", &true_value);
556     v->Visit("false_value", &false_value);
557   }
558 
SEqualReduce(const SelectNode * other,SEqualReducer equal)559   bool SEqualReduce(const SelectNode* other, SEqualReducer equal) const {
560     return equal(dtype, other->dtype) && equal(condition, other->condition) &&
561            equal(true_value, other->true_value) && equal(false_value, other->false_value);
562   }
563 
SHashReduce(SHashReducer hash_reduce)564   void SHashReduce(SHashReducer hash_reduce) const {
565     hash_reduce(dtype);
566     hash_reduce(condition);
567     hash_reduce(true_value);
568     hash_reduce(false_value);
569   }
570 
571   static constexpr const char* _type_key = "tir.Select";
572   TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, PrimExprNode);
573 };
574 
575 /*!
576  * \brief Managed reference to SelectNode
577  * \sa SelectNode
578  */
579 class Select : public PrimExpr {
580  public:
581   TVM_DLL Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value);
582 
583   TVM_DEFINE_OBJECT_REF_METHODS(Select, PrimExpr, SelectNode);
584 };
585 
586 /*!
587  * \brief Load value from the high dimension buffer.
588  *
589  * \code
590  *
591  *  value = buffer[i, j];
592  *
593  * \endcode
594  * \sa BufferStore
595  */
596 class BufferLoadNode : public PrimExprNode {
597  public:
598   /*! \brief The buffer variable. */
599   Buffer buffer;
600   /*! \brief The indices location to be loaded. */
601   Array<PrimExpr> indices;
602 
VisitAttrs(AttrVisitor * v)603   void VisitAttrs(AttrVisitor* v) {
604     v->Visit("dtype", &(this->dtype));
605     v->Visit("buffer", &buffer);
606     v->Visit("indices", &indices);
607   }
608 
SEqualReduce(const BufferLoadNode * other,SEqualReducer equal)609   bool SEqualReduce(const BufferLoadNode* other, SEqualReducer equal) const {
610     return equal(dtype, other->dtype) && equal(buffer, other->buffer) &&
611            equal(indices, other->indices);
612   }
613 
SHashReduce(SHashReducer hash_reduce)614   void SHashReduce(SHashReducer hash_reduce) const {
615     hash_reduce(dtype);
616     hash_reduce(buffer);
617     hash_reduce(indices);
618   }
619 
620   static constexpr const char* _type_key = "tir.BufferLoad";
621   TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode);
622 };
623 
624 /*!
625  * \brief Managed reference to BufferLoadNode.
626  * \sa BufferLoadNode
627  */
628 class BufferLoad : public PrimExpr {
629  public:
630   TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices);
631   TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode);
632 };
633 
634 /*!
635  * \brief Load value from the result produced by the producer.
636  *
637  * \note This node only appears in high-level DSLs that are built on top of the TIR.
638  *       It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower
639  *       this node before TIR transformations.
640  *
641  * \sa ProducerLoad, DataProducerNode
642  */
643 class ProducerLoadNode : public PrimExprNode {
644  public:
645   /*! \brief The buffer producer. */
646   DataProducer producer;
647   /*! \brief The location arguments. */
648   Array<PrimExpr> indices;
649 
VisitAttrs(AttrVisitor * v)650   void VisitAttrs(AttrVisitor* v) {
651     v->Visit("dtype", &(this->dtype));
652     v->Visit("producer", &producer);
653     v->Visit("indices", &indices);
654   }
655 
SEqualReduce(const ProducerLoadNode * other,SEqualReducer equal)656   bool SEqualReduce(const ProducerLoadNode* other, SEqualReducer equal) const {
657     return equal(dtype, other->dtype) && equal(producer, other->producer) &&
658            equal(indices, other->indices);
659   }
660 
SHashReduce(SHashReducer hash_reduce)661   void SHashReduce(SHashReducer hash_reduce) const {
662     hash_reduce(dtype);
663     hash_reduce(producer);
664     hash_reduce(indices);
665   }
666 
667   static constexpr const char* _type_key = "tir.ProducerLoad";
668   TVM_DECLARE_FINAL_OBJECT_INFO(ProducerLoadNode, PrimExprNode);
669 };
670 
671 /*!
672  * \brief Managed reference to ProducerLoadNode.
673  * \sa ProducerLoadNode
674  */
675 class ProducerLoad : public PrimExpr {
676  public:
677   TVM_DLL explicit ProducerLoad(DataProducer producer, Array<PrimExpr> indices);
678 
679   TVM_DEFINE_OBJECT_REF_METHODS(ProducerLoad, PrimExpr, ProducerLoadNode);
680 };
681 
682 /*!
683  * \brief Load the value from buffer_var.
684  *
685  *  Equivalent to ((DType*)buffer_var)[index]
686  *  where DType is the type specified by type().element_of().
687  *
688  *  For example, if type = float32x3, then the load will corresponds to
689  *
690  * \code
691  *
692  *  auto buffer = static_cast<float*>(buffer_var);
693  *  auto loaded_val = float32x3(buffer[index.v0], buffer[index.v1], buffer[index.v2]);
694  *
695  * \endcode
696  */
697 class LoadNode : public PrimExprNode {
698  public:
699   /*! \brief The buffer variable. */
700   Var buffer_var;
701   /*! \brief The index locations to be loaded. */
702   PrimExpr index;
703   /*! \brief The predicate to mask which lanes would be loaded. */
704   PrimExpr predicate;
705 
VisitAttrs(AttrVisitor * v)706   void VisitAttrs(AttrVisitor* v) {
707     v->Visit("dtype", &dtype);
708     v->Visit("buffer_var", &buffer_var);
709     v->Visit("index", &index);
710     v->Visit("predicate", &predicate);
711   }
712 
SEqualReduce(const LoadNode * other,SEqualReducer equal)713   bool SEqualReduce(const LoadNode* other, SEqualReducer equal) const {
714     return equal(dtype, other->dtype) && equal(buffer_var, other->buffer_var) &&
715            equal(index, other->index) && equal(predicate, other->predicate);
716   }
717 
SHashReduce(SHashReducer hash_reduce)718   void SHashReduce(SHashReducer hash_reduce) const {
719     hash_reduce(dtype);
720     hash_reduce(buffer_var);
721     hash_reduce(index);
722     hash_reduce(predicate);
723   }
724 
725   static constexpr const char* _type_key = "tir.Load";
726   TVM_DECLARE_FINAL_OBJECT_INFO(LoadNode, PrimExprNode);
727 };
728 
729 /*!
730  * \brief Managed reference to LoadNode
731  * \sa LoadNode
732  */
733 class Load : public PrimExpr {
734  public:
735   TVM_DLL Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate);
736   TVM_DEFINE_OBJECT_REF_METHODS(Load, PrimExpr, LoadNode);
737 };
738 
739 /*!
740  * \brief Construct a vector with lanes elements
741  *        where its i-th element equals base + i * stride.
742  *  This is useful to construct a index for a continuous vector load.
743  *
744  *  Examples:
745  *  - ramp(0, 1, 3) = [0, 1, 2]
746  *  - ramp(1, 2, 4) = [1, 3, 5, 7]
747  */
748 class RampNode : public PrimExprNode {
749  public:
750   /*! \brief The base value. */
751   PrimExpr base;
752   /*! \brief The stride of each step. */
753   PrimExpr stride;
754   /*! \brief Total number of lanes. */
755   int lanes;
756 
VisitAttrs(AttrVisitor * v)757   void VisitAttrs(AttrVisitor* v) {
758     v->Visit("dtype", &dtype);
759     v->Visit("base", &base);
760     v->Visit("stride", &stride);
761     v->Visit("lanes", &lanes);
762   }
763 
SEqualReduce(const RampNode * other,SEqualReducer equal)764   bool SEqualReduce(const RampNode* other, SEqualReducer equal) const {
765     return equal(dtype, other->dtype) && equal(base, other->base) && equal(stride, other->stride) &&
766            equal(lanes, other->lanes);
767   }
768 
SHashReduce(SHashReducer hash_reduce)769   void SHashReduce(SHashReducer hash_reduce) const {
770     hash_reduce(dtype);
771     hash_reduce(base);
772     hash_reduce(stride);
773     hash_reduce(lanes);
774   }
775 
776   static constexpr const char* _type_key = "tir.Ramp";
777   TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, PrimExprNode);
778 };
779 
780 /*!
781  * \brief Managed reference to RampNode
782  * \sa RampNode
783  */
784 class Ramp : public PrimExpr {
785  public:
786   TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes);
787   TVM_DEFINE_OBJECT_REF_METHODS(Ramp, PrimExpr, RampNode);
788 };
789 
790 /*! \brief Create a vector where all the elements are value. */
791 class BroadcastNode : public PrimExprNode {
792  public:
793   /*! \brief The base value. */
794   PrimExpr value;
795   /*! \brief The number of lanes. */
796   int lanes;
797 
VisitAttrs(AttrVisitor * v)798   void VisitAttrs(AttrVisitor* v) {
799     v->Visit("dtype", &dtype);
800     v->Visit("value", &value);
801     v->Visit("lanes", &lanes);
802   }
803 
SEqualReduce(const BroadcastNode * other,SEqualReducer equal)804   bool SEqualReduce(const BroadcastNode* other, SEqualReducer equal) const {
805     return equal(dtype, other->dtype) && equal(value, other->value) && equal(lanes, other->lanes);
806   }
807 
SHashReduce(SHashReducer hash_reduce)808   void SHashReduce(SHashReducer hash_reduce) const {
809     hash_reduce(dtype);
810     hash_reduce(value);
811     hash_reduce(lanes);
812   }
813 
814   static constexpr const char* _type_key = "tir.Broadcast";
815   TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, PrimExprNode);
816 };
817 
818 /*!
819  * \brief Managed reference to BroadcastNode
820  * \sa BroadcastNode
821  */
822 class Broadcast : public PrimExpr {
823  public:
824   TVM_DLL Broadcast(PrimExpr value, int lanes);
825   TVM_DEFINE_OBJECT_REF_METHODS(Broadcast, PrimExpr, BroadcastNode);
826 };
827 
828 /*!
829  * \brief Let binding. Bind var to value then evaluate body.
830  */
831 class LetNode : public PrimExprNode {
832  public:
833   /*! \brief The variable. */
834   Var var;
835   /*! \brief The value to be binded. */
836   PrimExpr value;
837   /*! \brief The result expression. */
838   PrimExpr body;
839 
VisitAttrs(AttrVisitor * v)840   void VisitAttrs(AttrVisitor* v) {
841     v->Visit("dtype", &dtype);
842     v->Visit("var", &var);
843     v->Visit("value", &value);
844     v->Visit("body", &body);
845   }
846 
SEqualReduce(const LetNode * other,SEqualReducer equal)847   bool SEqualReduce(const LetNode* other, SEqualReducer equal) const {
848     return equal(dtype, other->dtype) && equal.DefEqual(var, other->var) &&
849            equal(value, other->value) && equal(body, other->body);
850   }
851 
SHashReduce(SHashReducer hash_reduce)852   void SHashReduce(SHashReducer hash_reduce) const {
853     hash_reduce(dtype);
854     hash_reduce.DefHash(var);
855     hash_reduce(value);
856     hash_reduce(body);
857   }
858 
859   static constexpr const char* _type_key = "tir.Let";
860   TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, PrimExprNode);
861 };
862 
863 /*!
864  * \brief Managed reference to LetNode
865  * \sa LetNode
866  */
867 class Let : public PrimExpr {
868  public:
869   TVM_DLL Let(Var var, PrimExpr value, PrimExpr body);
870   TVM_DEFINE_OBJECT_REF_METHODS(Let, PrimExpr, LetNode);
871 };
872 
873 /*!
874  * \brief Call node.
875  */
876 class CallNode : public PrimExprNode {
877  public:
878   /*!
879    * \brief The operator(function) being invoked
880    *
881    *  - It can be tvm::Op which corresponds to the primitive operators(intrinsics).
882    *  - It can also be another function in the IRModule (GlobalVar).
883    */
884   RelayExpr op;
885 
886   /*! \brief The arguments. */
887   Array<PrimExpr> args;
VisitAttrs(AttrVisitor * v)888   void VisitAttrs(AttrVisitor* v) {
889     v->Visit("dtype", &dtype);
890     v->Visit("op", &op);
891     v->Visit("args", &args);
892   }
893 
SEqualReduce(const CallNode * other,SEqualReducer equal)894   bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
895     return equal(dtype, other->dtype) && equal(op, other->op) && equal(args, other->args);
896   }
897 
SHashReduce(SHashReducer hash_reduce)898   void SHashReduce(SHashReducer hash_reduce) const {
899     hash_reduce(dtype);
900     hash_reduce(op);
901     hash_reduce(args);
902   }
903 
904   static constexpr const char* _type_key = "tir.Call";
905   TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode);
906 };
907 
908 /*!
909  * \brief Managed reference to CallNode
910  * \sa CallNode
911  */
912 class Call : public PrimExpr {
913  public:
914   TVM_DLL Call(DataType dtype, RelayExpr op, Array<PrimExpr> args);
915   TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode);
916 };
917 
918 /*!
919  * \brief Shuffle instruction.
920  *  vec = concat(vectors)
921  *  result = (vec[indices[0]], vec[indices[1]] ...)
922  */
923 class ShuffleNode : public PrimExprNode {
924  public:
925   /*! \brief the input vectors. */
926   Array<PrimExpr> vectors;
927   /*! \brief The indices of each element. */
928   Array<PrimExpr> indices;
929 
VisitAttrs(AttrVisitor * v)930   void VisitAttrs(AttrVisitor* v) {
931     v->Visit("vectors", &vectors);
932     v->Visit("indices", &indices);
933   }
934 
SEqualReduce(const ShuffleNode * other,SEqualReducer equal)935   bool SEqualReduce(const ShuffleNode* other, SEqualReducer equal) const {
936     return equal(dtype, other->dtype) && equal(vectors, other->vectors) &&
937            equal(indices, other->indices);
938   }
939 
SHashReduce(SHashReducer hash_reduce)940   void SHashReduce(SHashReducer hash_reduce) const {
941     hash_reduce(dtype);
942     hash_reduce(vectors);
943     hash_reduce(indices);
944   }
945 
946   static constexpr const char* _type_key = "tir.Shuffle";
947   TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, PrimExprNode);
948 };
949 
950 /*!
951  * \brief Managed reference to ShuffleNode
952  * \sa ShuffleNode
953  */
954 class Shuffle : public PrimExpr {
955  public:
956   TVM_DLL Shuffle(Array<PrimExpr> vectors, Array<PrimExpr> indices);
957   TVM_DLL static PrimExpr Concat(Array<PrimExpr> vectors);
958   TVM_DLL static PrimExpr ExtractElement(PrimExpr vector, int index);
959 
960   TVM_DEFINE_OBJECT_REF_METHODS(Shuffle, PrimExpr, ShuffleNode);
961 };
962 
963 // Reduce operator
964 /*!
965  * \brief A commutative reducer node to represent a commutative
966  *  binary operator with identity element
967  */
968 class CommReducerNode : public Object {
969  public:
970   /*! \brief The left argument of reducer */
971   Array<Var> lhs;
972   /*! \brief The right argument of reducer */
973   Array<Var> rhs;
974   /*! \brief The result of reducer */
975   Array<PrimExpr> result;
976   /*!
977    * \brief The identity element of reducer, which leaves other
978    *  elements unchanged when combined with it, with respect to
979    *  the binary operation of this reducer uses.
980    */
981   Array<PrimExpr> identity_element;
982   /*! \brief Function call operator to combine a and b */
983   Array<PrimExpr> operator()(Array<PrimExpr> a, Array<PrimExpr> b) const;
984 
VisitAttrs(AttrVisitor * v)985   void VisitAttrs(AttrVisitor* v) {
986     v->Visit("lhs", &lhs);
987     v->Visit("rhs", &rhs);
988     v->Visit("result", &result);
989     v->Visit("identity_element", &identity_element);
990   }
991 
SEqualReduce(const CommReducerNode * other,SEqualReducer equal)992   bool SEqualReduce(const CommReducerNode* other, SEqualReducer equal) const {
993     return equal.DefEqual(lhs, other->lhs) && equal.DefEqual(rhs, other->rhs) &&
994            equal(result, other->result) && equal(identity_element, other->identity_element);
995   }
996 
SHashReduce(SHashReducer hash_reduce)997   void SHashReduce(SHashReducer hash_reduce) const {
998     hash_reduce.DefHash(lhs);
999     hash_reduce.DefHash(rhs);
1000     hash_reduce(result);
1001     hash_reduce(identity_element);
1002   }
1003 
1004   static constexpr const char* _type_key = "tir.CommReducer";
1005   static constexpr const bool _type_has_method_sequal_reduce = true;
1006   static constexpr const bool _type_has_method_shash_reduce = true;
1007   TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object);
1008 };
1009 
1010 /*!
1011  * \brief Managed reference to CommReducerNode
1012  * \sa CommReducerNode
1013  */
1014 class CommReducer : public ObjectRef {
1015  public:
1016   TVM_DLL CommReducer(Array<Var> lhs, Array<Var> rhs, Array<PrimExpr> result,
1017                       Array<PrimExpr> identity_element);
1018 
1019   TVM_DEFINE_OBJECT_REF_METHODS(CommReducer, ObjectRef, CommReducerNode);
1020 };
1021 
1022 /*! \brief Reduction operator operator */
1023 class ReduceNode : public PrimExprNode {
1024  public:
1025   /*! \brief The commutative combiner */
1026   CommReducer combiner;
1027   /*! \brief The source operand */
1028   Array<PrimExpr> source;
1029   /*! \brief The init operand */
1030   Array<PrimExpr> init;
1031   /*! \brief The reduction axis */
1032   Array<IterVar> axis;
1033   /*!
1034    * \brief Predicate on the reduction
1035    *  Only add the body to reduction if condition is true.
1036    */
1037   PrimExpr condition;
1038   /*! \brief the index of this reduce node */
1039   int value_index;
1040 
VisitAttrs(AttrVisitor * v)1041   void VisitAttrs(AttrVisitor* v) {
1042     v->Visit("dtype", &dtype);
1043     v->Visit("combiner", &combiner);
1044     v->Visit("source", &source);
1045     v->Visit("init", &init);
1046     v->Visit("axis", &axis);
1047     v->Visit("condition", &condition);
1048     v->Visit("value_index", &value_index);
1049   }
1050 
SEqualReduce(const ReduceNode * other,SEqualReducer equal)1051   bool SEqualReduce(const ReduceNode* other, SEqualReducer equal) const {
1052     // check axis first so IterVars can define the necessary variables.
1053     return equal(dtype, other->dtype) && equal(axis, other->axis) &&
1054            equal(combiner, other->combiner) && equal(source, other->source) &&
1055            equal(init, other->init) && equal(condition, other->condition) &&
1056            equal(value_index, other->value_index);
1057   }
1058 
SHashReduce(SHashReducer hash_reduce)1059   void SHashReduce(SHashReducer hash_reduce) const {
1060     hash_reduce(dtype);
1061     hash_reduce(axis);
1062     hash_reduce(combiner);
1063     hash_reduce(source);
1064     hash_reduce(init);
1065     hash_reduce(condition);
1066     hash_reduce(value_index);
1067   }
1068 
1069   static constexpr const char* _type_key = "tir.Reduce";
1070   TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode);
1071 };
1072 
1073 /*!
1074  * \brief Managed reference to ReduceNode
1075  * \sa ReduceNode
1076  */
1077 class Reduce : public PrimExpr {
1078  public:
1079   TVM_DLL Reduce(CommReducer combiner, Array<PrimExpr> src, Array<IterVar> rdom, PrimExpr condition,
1080                  int value_index, Array<PrimExpr> init);
1081 
1082   TVM_DEFINE_OBJECT_REF_METHODS(Reduce, PrimExpr, ReduceNode);
1083 };
1084 
1085 /*! \brief Any shape. */
1086 class AnyNode : public PrimExprNode {
1087  public:
VisitAttrs(AttrVisitor * v)1088   void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); }
1089 
SEqualReduce(const AnyNode * other,SEqualReducer equal)1090   bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const {
1091     return equal(dtype, other->dtype);
1092   }
1093 
SHashReduce(SHashReducer hash_reduce)1094   void SHashReduce(SHashReducer hash_reduce) const {}
1095 
1096   /*! \brief Convert to var. */
ToVar()1097   Var ToVar() const { return Var("any_dim", DataType::Int(32)); }
1098 
1099   static constexpr const char* _type_key = "tir.Any";
1100   TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, PrimExprNode);
1101 };
1102 
1103 /*!
1104  * \brief Managed reference to AnyNode
1105  * \sa AnyNode
1106  */
1107 class Any : public PrimExpr {
1108  public:
1109   TVM_DLL Any();
1110 
1111   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Any, PrimExpr, AnyNode);
1112 };
1113 
1114 /*
1115  * \brief Template function to convert Map to unordered_map
1116  *  Sometimes useful for API gluing when internal uses unordered_map
1117  * \param dmap The container map
1118  * \return The corresponding unordered_map.
1119  * \tparam K the key of the Map.
1120  * \tparam V the value of the Map.
1121  */
1122 template <typename K, typename V>
as_unordered_map(const Map<K,V> & dmap)1123 inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
1124   std::unordered_map<K, V> ret;
1125   for (auto kv : dmap) {
1126     ret[kv.first] = kv.second;
1127   }
1128   return ret;
1129 }
1130 }  // namespace tir
1131 }  // namespace tvm
1132 
1133 namespace std {
1134 template <>
1135 struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectPtrHash {};
1136 }  // namespace std
1137 #endif  // TVM_TIR_EXPR_H_
1138