1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file tvm/tir/expr.h
22 * \brief TIR expressions.
23 */
24 // Acknowledgement: Many low-level IR nodes originate from Halide.
25 #ifndef TVM_TIR_EXPR_H_
26 #define TVM_TIR_EXPR_H_
27
28 #include <tvm/ir/expr.h>
29 #include <tvm/node/container.h>
30 #include <tvm/node/functor.h>
31 #include <tvm/node/node.h>
32 #include <tvm/runtime/c_runtime_api.h>
33 #include <tvm/runtime/data_type.h>
34 #include <tvm/tir/buffer.h>
35 #include <tvm/tir/var.h>
36
37 #include <algorithm>
38 #include <iostream>
39 #include <limits>
40 #include <string>
41 #include <unordered_map>
42 #include <utility>
43
44 namespace tvm {
45 namespace tir {
46
47 using IntImmNode = tvm::IntImmNode;
48 using FloatImmNode = tvm::FloatImmNode;
49
50 /*! \brief String constants, only used in asserts. */
51 class StringImmNode : public PrimExprNode {
52 public:
53 /*! \brief The constant value content. */
54 String value;
55
VisitAttrs(AttrVisitor * v)56 void VisitAttrs(AttrVisitor* v) {
57 v->Visit("dtype", &dtype);
58 v->Visit("value", &value);
59 }
60
SEqualReduce(const StringImmNode * other,SEqualReducer equal)61 bool SEqualReduce(const StringImmNode* other, SEqualReducer equal) const {
62 return equal(value, other->value);
63 }
64
SHashReduce(SHashReducer hash_reduce)65 void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
66
67 static constexpr const char* _type_key = "tir.StringImm";
68 TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, PrimExprNode);
69 };
70
71 /*!
72 * \brief Managed reference to StringImmNode.
73 * \sa StringImmNode
74 */
75 class StringImm : public PrimExpr {
76 public:
77 TVM_DLL StringImm(String value);
78 TVM_DEFINE_OBJECT_REF_METHODS(StringImm, PrimExpr, StringImmNode);
79 };
80
81 /*!
82 * \brief Cast value from one data type to another.
83 * \note The lanes of value should keep fixed.
84 */
85 class CastNode : public PrimExprNode {
86 public:
87 /*! \brief Original data type. */
88 PrimExpr value;
89
VisitAttrs(AttrVisitor * v)90 void VisitAttrs(AttrVisitor* v) {
91 v->Visit("dtype", &dtype);
92 v->Visit("value", &value);
93 }
94
SEqualReduce(const CastNode * other,SEqualReducer equal)95 bool SEqualReduce(const CastNode* other, SEqualReducer equal) const {
96 return equal(dtype, other->dtype) && equal(value, other->value);
97 }
98
SHashReduce(SHashReducer hash_reduce)99 void SHashReduce(SHashReducer hash_reduce) const {
100 hash_reduce(dtype);
101 hash_reduce(value);
102 }
103
104 static constexpr const char* _type_key = "tir.Cast";
105 TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, PrimExprNode);
106 };
107
108 /*!
109 * \brief Managed reference to CastNode
110 * \sa CastNode
111 */
112 class Cast : public PrimExpr {
113 public:
114 TVM_DLL Cast(DataType dtype, PrimExpr value);
115 TVM_DEFINE_OBJECT_REF_METHODS(Cast, PrimExpr, CastNode);
116 };
117
118 /*!
119 * \brief Base template to implement binary ops.
120 * \tparam T The type of the child class.
121 */
122 template <typename T>
123 class BinaryOpNode : public PrimExprNode {
124 public:
125 /*! \brief The left operand. */
126 PrimExpr a;
127 /*! \brief The right operand. */
128 PrimExpr b;
129
VisitAttrs(AttrVisitor * v)130 void VisitAttrs(AttrVisitor* v) {
131 v->Visit("dtype", &(this->dtype));
132 v->Visit("a", &a);
133 v->Visit("b", &b);
134 }
135
SEqualReduce(const T * other,SEqualReducer equal)136 bool SEqualReduce(const T* other, SEqualReducer equal) const {
137 return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b);
138 }
139
SHashReduce(SHashReducer hash_reduce)140 void SHashReduce(SHashReducer hash_reduce) const {
141 hash_reduce(dtype);
142 hash_reduce(a);
143 hash_reduce(b);
144 }
145
146 TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode);
147 };
148
149 /*! \brief a + b */
150 class AddNode : public BinaryOpNode<AddNode> {
151 public:
152 static constexpr const char* _type_key = "tir.Add";
153 };
154
155 /*!
156 * \brief Managed reference to AddNode
157 * \sa AddNode
158 */
159 class Add : public PrimExpr {
160 public:
161 TVM_DLL Add(PrimExpr a, PrimExpr b);
162 TVM_DEFINE_OBJECT_REF_METHODS(Add, PrimExpr, AddNode);
163 };
164
165 /*! \brief a - b */
166 class SubNode : public BinaryOpNode<SubNode> {
167 public:
168 static constexpr const char* _type_key = "tir.Sub";
169 };
170
171 /*!
172 * \brief Managed reference to SubNode
173 * \sa SubNode
174 */
175 class Sub : public PrimExpr {
176 public:
177 TVM_DLL Sub(PrimExpr a, PrimExpr b);
178 TVM_DEFINE_OBJECT_REF_METHODS(Sub, PrimExpr, SubNode);
179 };
180
181 /*! \brief a * b */
182 class MulNode : public BinaryOpNode<MulNode> {
183 public:
184 static constexpr const char* _type_key = "tir.Mul";
185 };
186
187 /*!
188 * \brief Managed reference to MulNode
189 * \sa MulNode
190 */
191 class Mul : public PrimExpr {
192 public:
193 TVM_DLL Mul(PrimExpr a, PrimExpr b);
194 TVM_DEFINE_OBJECT_REF_METHODS(Mul, PrimExpr, MulNode);
195 };
196
197 /*!
198 * \brief a / b in the C semnatics.
199 * \note For integer division, C standard uses trunc div.
200 */
201 class DivNode : public BinaryOpNode<DivNode> {
202 public:
203 static constexpr const char* _type_key = "tir.Div";
204 };
205
206 /*!
207 * \brief Managed reference to DivNode
208 * \sa DivNode
209 */
210 class Div : public PrimExpr {
211 public:
212 TVM_DLL Div(PrimExpr a, PrimExpr b);
213 TVM_DEFINE_OBJECT_REF_METHODS(Div, PrimExpr, DivNode);
214 };
215
216 /*!
217 * \brief a % b in the C semnatics.
218 * \note For integer division, C standard uses trunc div.
219 */
220 class ModNode : public BinaryOpNode<ModNode> {
221 public:
222 static constexpr const char* _type_key = "tir.Mod";
223 };
224
225 /*!
226 * \brief Managed reference to ModNode
227 * \sa ModNode
228 */
229 class Mod : public PrimExpr {
230 public:
231 TVM_DLL Mod(PrimExpr a, PrimExpr b);
232 TVM_DEFINE_OBJECT_REF_METHODS(Mod, PrimExpr, ModNode);
233 };
234
235 /*! \brief Floor division, floor(a/b) */
236 class FloorDivNode : public BinaryOpNode<FloorDivNode> {
237 public:
238 static constexpr const char* _type_key = "tir.FloorDiv";
239 };
240
241 /*!
242 * \brief Managed reference to FloorDivNode
243 * \sa FloorDivNode
244 */
245 class FloorDiv : public PrimExpr {
246 public:
247 TVM_DLL FloorDiv(PrimExpr a, PrimExpr b);
248 TVM_DEFINE_OBJECT_REF_METHODS(FloorDiv, PrimExpr, FloorDivNode);
249 };
250
251 /*! \brief The remainder of the floordiv */
252 class FloorModNode : public BinaryOpNode<FloorModNode> {
253 public:
254 static constexpr const char* _type_key = "tir.FloorMod";
255 };
256
257 /*!
258 * \brief Managed reference to FloorModNode
259 * \sa FloorModNode
260 */
261 class FloorMod : public PrimExpr {
262 public:
263 TVM_DLL FloorMod(PrimExpr a, PrimExpr b);
264 TVM_DEFINE_OBJECT_REF_METHODS(FloorMod, PrimExpr, FloorModNode);
265 };
266
267 /*! \brief min(a, b) */
268 class MinNode : public BinaryOpNode<MinNode> {
269 public:
270 static constexpr const char* _type_key = "tir.Min";
271 };
272
273 /*!
274 * \brief Managed reference to MinNode
275 * \sa MinNode
276 */
277 class Min : public PrimExpr {
278 public:
279 TVM_DLL Min(PrimExpr a, PrimExpr b);
280 TVM_DEFINE_OBJECT_REF_METHODS(Min, PrimExpr, MinNode);
281 };
282
283 /*! \brief max(a, b) */
284 class MaxNode : public BinaryOpNode<MaxNode> {
285 public:
286 static constexpr const char* _type_key = "tir.Max";
287 };
288
289 /*!
290 * \brief Managed reference to MaxNode
291 * \sa MaxNode
292 */
293 class Max : public PrimExpr {
294 public:
295 TVM_DLL Max(PrimExpr a, PrimExpr b);
296 TVM_DEFINE_OBJECT_REF_METHODS(Max, PrimExpr, MaxNode);
297 };
298
299 /*!
300 * \brief Base template to implement comparison ops.
301 * \tparam T The type of the child class.
302 */
303 template <typename T>
304 class CmpOpNode : public PrimExprNode {
305 public:
306 /*! \brief The left operand. */
307 PrimExpr a;
308 /*! \brief The right operand. */
309 PrimExpr b;
310
VisitAttrs(AttrVisitor * v)311 void VisitAttrs(AttrVisitor* v) {
312 v->Visit("dtype", &(this->dtype));
313 v->Visit("a", &a);
314 v->Visit("b", &b);
315 }
316
SEqualReduce(const T * other,SEqualReducer equal)317 bool SEqualReduce(const T* other, SEqualReducer equal) const {
318 return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b);
319 }
320
SHashReduce(SHashReducer hash_reduce)321 void SHashReduce(SHashReducer hash_reduce) const {
322 hash_reduce(dtype);
323 hash_reduce(a);
324 hash_reduce(b);
325 }
326
327 TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode);
328 };
329
330 /*! \brief a == b */
331 class EQNode : public CmpOpNode<EQNode> {
332 public:
333 static constexpr const char* _type_key = "tir.EQ";
334 };
335
336 /*!
337 * \brief Managed reference to EQNode
338 * \sa EQNode
339 */
340 class EQ : public PrimExpr {
341 public:
342 TVM_DLL EQ(PrimExpr a, PrimExpr b);
343 TVM_DEFINE_OBJECT_REF_METHODS(EQ, PrimExpr, EQNode);
344 };
345
346 /*! \brief a != b */
347 class NENode : public CmpOpNode<NENode> {
348 public:
349 static constexpr const char* _type_key = "tir.NE";
350 };
351
352 /*!
353 * \brief Managed reference to NENode
354 * \sa NENode
355 */
356 class NE : public PrimExpr {
357 public:
358 TVM_DLL NE(PrimExpr a, PrimExpr b);
359 TVM_DEFINE_OBJECT_REF_METHODS(NE, PrimExpr, NENode);
360 };
361
362 /*! \brief a < b */
363 class LTNode : public CmpOpNode<LTNode> {
364 public:
365 static constexpr const char* _type_key = "tir.LT";
366 };
367
368 /*!
369 * \brief Managed reference to LTNode
370 * \sa LTNode
371 */
372 class LT : public PrimExpr {
373 public:
374 TVM_DLL LT(PrimExpr a, PrimExpr b);
375 TVM_DEFINE_OBJECT_REF_METHODS(LT, PrimExpr, LTNode);
376 };
377
378 /*! \brief a <= b */
379 struct LENode : public CmpOpNode<LENode> {
380 public:
381 static constexpr const char* _type_key = "tir.LE";
382 };
383
384 /*!
385 * \brief Managed reference to LENode
386 * \sa LENode
387 */
388 class LE : public PrimExpr {
389 public:
390 TVM_DLL LE(PrimExpr a, PrimExpr b);
391 TVM_DEFINE_OBJECT_REF_METHODS(LE, PrimExpr, LENode);
392 };
393
394 /*! \brief a > b */
395 class GTNode : public CmpOpNode<GTNode> {
396 public:
397 static constexpr const char* _type_key = "tir.GT";
398 };
399
400 /*!
401 * \brief Managed reference to GTNode
402 * \sa GTNode
403 */
404 class GT : public PrimExpr {
405 public:
406 TVM_DLL GT(PrimExpr a, PrimExpr b);
407 TVM_DEFINE_OBJECT_REF_METHODS(GT, PrimExpr, GTNode);
408 };
409
410 /*! \brief a >= b */
411 class GENode : public CmpOpNode<GENode> {
412 public:
413 static constexpr const char* _type_key = "tir.GE";
414 };
415
416 /*!
417 * \brief Managed reference to GENode
418 * \sa GENode
419 */
420 class GE : public PrimExpr {
421 public:
422 TVM_DLL GE(PrimExpr a, PrimExpr b);
423 TVM_DEFINE_OBJECT_REF_METHODS(GE, PrimExpr, GENode);
424 };
425
426 /*! \brief a && b */
427 class AndNode : public PrimExprNode {
428 public:
429 /*! \brief The left operand. */
430 PrimExpr a;
431 /*! \brief The right operand. */
432 PrimExpr b;
433
VisitAttrs(AttrVisitor * v)434 void VisitAttrs(AttrVisitor* v) {
435 v->Visit("dtype", &(this->dtype));
436 v->Visit("a", &a);
437 v->Visit("b", &b);
438 }
439
SEqualReduce(const AndNode * other,SEqualReducer equal)440 bool SEqualReduce(const AndNode* other, SEqualReducer equal) const {
441 return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b);
442 }
443
SHashReduce(SHashReducer hash_reduce)444 void SHashReduce(SHashReducer hash_reduce) const {
445 hash_reduce(dtype);
446 hash_reduce(a);
447 hash_reduce(b);
448 }
449
450 static constexpr const char* _type_key = "tir.And";
451 TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, PrimExprNode);
452 };
453
454 /*!
455 * \brief Managed reference to AndNode
456 * \sa AndNode
457 */
458 class And : public PrimExpr {
459 public:
460 TVM_DLL And(PrimExpr a, PrimExpr b);
461 TVM_DEFINE_OBJECT_REF_METHODS(And, PrimExpr, AndNode);
462 };
463
464 /*! \brief a || b */
465 class OrNode : public PrimExprNode {
466 public:
467 /*! \brief The left operand. */
468 PrimExpr a;
469 /*! \brief The right operand. */
470 PrimExpr b;
471
VisitAttrs(AttrVisitor * v)472 void VisitAttrs(AttrVisitor* v) {
473 v->Visit("dtype", &dtype);
474 v->Visit("a", &a);
475 v->Visit("b", &b);
476 }
477
SEqualReduce(const OrNode * other,SEqualReducer equal)478 bool SEqualReduce(const OrNode* other, SEqualReducer equal) const {
479 return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b);
480 }
481
SHashReduce(SHashReducer hash_reduce)482 void SHashReduce(SHashReducer hash_reduce) const {
483 hash_reduce(dtype);
484 hash_reduce(a);
485 hash_reduce(b);
486 }
487
488 static constexpr const char* _type_key = "tir.Or";
489 TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, PrimExprNode);
490 };
491
492 /*!
493 * \brief Managed reference to OrNode
494 * \sa OrNode
495 */
496 class Or : public PrimExpr {
497 public:
498 TVM_DLL Or(PrimExpr a, PrimExpr b);
499 TVM_DEFINE_OBJECT_REF_METHODS(Or, PrimExpr, OrNode);
500 };
501
502 /*! \brief !a */
503 class NotNode : public PrimExprNode {
504 public:
505 /*! \brief The input operand. */
506 PrimExpr a;
507
VisitAttrs(AttrVisitor * v)508 void VisitAttrs(AttrVisitor* v) {
509 v->Visit("dtype", &dtype);
510 v->Visit("a", &a);
511 }
512
SEqualReduce(const NotNode * other,SEqualReducer equal)513 bool SEqualReduce(const NotNode* other, SEqualReducer equal) const {
514 return equal(dtype, other->dtype) && equal(a, other->a);
515 }
516
SHashReduce(SHashReducer hash_reduce)517 void SHashReduce(SHashReducer hash_reduce) const {
518 hash_reduce(dtype);
519 hash_reduce(a);
520 }
521
522 static constexpr const char* _type_key = "tir.Not";
523 TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, PrimExprNode);
524 };
525
526 /*!
527 * \brief Managed reference to NotNode
528 * \sa NotNode
529 */
530 class Not : public PrimExpr {
531 public:
532 TVM_DLL Not(PrimExpr a);
533 TVM_DEFINE_OBJECT_REF_METHODS(Not, PrimExpr, NotNode);
534 };
535
536 /*!
537 * \brief return true_value if condition is true, otherwise return false_value.
538 * \note Both true_value and false_value could be evaluated
539 * regardless of the condition value.
540 * Do not use it to guard against out of bound access,
541 * please use if_then_else instead.
542 */
543 class SelectNode : public PrimExprNode {
544 public:
545 /*! \brief The condition */
546 PrimExpr condition;
547 /*! \brief value to be returned when condition is true. */
548 PrimExpr true_value;
549 /*! \brief value to be returned when condition is false. */
550 PrimExpr false_value;
551
VisitAttrs(AttrVisitor * v)552 void VisitAttrs(AttrVisitor* v) {
553 v->Visit("dtype", &dtype);
554 v->Visit("condition", &condition);
555 v->Visit("true_value", &true_value);
556 v->Visit("false_value", &false_value);
557 }
558
SEqualReduce(const SelectNode * other,SEqualReducer equal)559 bool SEqualReduce(const SelectNode* other, SEqualReducer equal) const {
560 return equal(dtype, other->dtype) && equal(condition, other->condition) &&
561 equal(true_value, other->true_value) && equal(false_value, other->false_value);
562 }
563
SHashReduce(SHashReducer hash_reduce)564 void SHashReduce(SHashReducer hash_reduce) const {
565 hash_reduce(dtype);
566 hash_reduce(condition);
567 hash_reduce(true_value);
568 hash_reduce(false_value);
569 }
570
571 static constexpr const char* _type_key = "tir.Select";
572 TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, PrimExprNode);
573 };
574
575 /*!
576 * \brief Managed reference to SelectNode
577 * \sa SelectNode
578 */
579 class Select : public PrimExpr {
580 public:
581 TVM_DLL Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value);
582
583 TVM_DEFINE_OBJECT_REF_METHODS(Select, PrimExpr, SelectNode);
584 };
585
586 /*!
587 * \brief Load value from the high dimension buffer.
588 *
589 * \code
590 *
591 * value = buffer[i, j];
592 *
593 * \endcode
594 * \sa BufferStore
595 */
596 class BufferLoadNode : public PrimExprNode {
597 public:
598 /*! \brief The buffer variable. */
599 Buffer buffer;
600 /*! \brief The indices location to be loaded. */
601 Array<PrimExpr> indices;
602
VisitAttrs(AttrVisitor * v)603 void VisitAttrs(AttrVisitor* v) {
604 v->Visit("dtype", &(this->dtype));
605 v->Visit("buffer", &buffer);
606 v->Visit("indices", &indices);
607 }
608
SEqualReduce(const BufferLoadNode * other,SEqualReducer equal)609 bool SEqualReduce(const BufferLoadNode* other, SEqualReducer equal) const {
610 return equal(dtype, other->dtype) && equal(buffer, other->buffer) &&
611 equal(indices, other->indices);
612 }
613
SHashReduce(SHashReducer hash_reduce)614 void SHashReduce(SHashReducer hash_reduce) const {
615 hash_reduce(dtype);
616 hash_reduce(buffer);
617 hash_reduce(indices);
618 }
619
620 static constexpr const char* _type_key = "tir.BufferLoad";
621 TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode);
622 };
623
624 /*!
625 * \brief Managed reference to BufferLoadNode.
626 * \sa BufferLoadNode
627 */
628 class BufferLoad : public PrimExpr {
629 public:
630 TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices);
631 TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode);
632 };
633
634 /*!
635 * \brief Load value from the result produced by the producer.
636 *
637 * \note This node only appears in high-level DSLs that are built on top of the TIR.
638 * It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower
639 * this node before TIR transformations.
640 *
641 * \sa ProducerLoad, DataProducerNode
642 */
643 class ProducerLoadNode : public PrimExprNode {
644 public:
645 /*! \brief The buffer producer. */
646 DataProducer producer;
647 /*! \brief The location arguments. */
648 Array<PrimExpr> indices;
649
VisitAttrs(AttrVisitor * v)650 void VisitAttrs(AttrVisitor* v) {
651 v->Visit("dtype", &(this->dtype));
652 v->Visit("producer", &producer);
653 v->Visit("indices", &indices);
654 }
655
SEqualReduce(const ProducerLoadNode * other,SEqualReducer equal)656 bool SEqualReduce(const ProducerLoadNode* other, SEqualReducer equal) const {
657 return equal(dtype, other->dtype) && equal(producer, other->producer) &&
658 equal(indices, other->indices);
659 }
660
SHashReduce(SHashReducer hash_reduce)661 void SHashReduce(SHashReducer hash_reduce) const {
662 hash_reduce(dtype);
663 hash_reduce(producer);
664 hash_reduce(indices);
665 }
666
667 static constexpr const char* _type_key = "tir.ProducerLoad";
668 TVM_DECLARE_FINAL_OBJECT_INFO(ProducerLoadNode, PrimExprNode);
669 };
670
671 /*!
672 * \brief Managed reference to ProducerLoadNode.
673 * \sa ProducerLoadNode
674 */
675 class ProducerLoad : public PrimExpr {
676 public:
677 TVM_DLL explicit ProducerLoad(DataProducer producer, Array<PrimExpr> indices);
678
679 TVM_DEFINE_OBJECT_REF_METHODS(ProducerLoad, PrimExpr, ProducerLoadNode);
680 };
681
682 /*!
683 * \brief Load the value from buffer_var.
684 *
685 * Equivalent to ((DType*)buffer_var)[index]
686 * where DType is the type specified by type().element_of().
687 *
688 * For example, if type = float32x3, then the load will corresponds to
689 *
690 * \code
691 *
692 * auto buffer = static_cast<float*>(buffer_var);
693 * auto loaded_val = float32x3(buffer[index.v0], buffer[index.v1], buffer[index.v2]);
694 *
695 * \endcode
696 */
697 class LoadNode : public PrimExprNode {
698 public:
699 /*! \brief The buffer variable. */
700 Var buffer_var;
701 /*! \brief The index locations to be loaded. */
702 PrimExpr index;
703 /*! \brief The predicate to mask which lanes would be loaded. */
704 PrimExpr predicate;
705
VisitAttrs(AttrVisitor * v)706 void VisitAttrs(AttrVisitor* v) {
707 v->Visit("dtype", &dtype);
708 v->Visit("buffer_var", &buffer_var);
709 v->Visit("index", &index);
710 v->Visit("predicate", &predicate);
711 }
712
SEqualReduce(const LoadNode * other,SEqualReducer equal)713 bool SEqualReduce(const LoadNode* other, SEqualReducer equal) const {
714 return equal(dtype, other->dtype) && equal(buffer_var, other->buffer_var) &&
715 equal(index, other->index) && equal(predicate, other->predicate);
716 }
717
SHashReduce(SHashReducer hash_reduce)718 void SHashReduce(SHashReducer hash_reduce) const {
719 hash_reduce(dtype);
720 hash_reduce(buffer_var);
721 hash_reduce(index);
722 hash_reduce(predicate);
723 }
724
725 static constexpr const char* _type_key = "tir.Load";
726 TVM_DECLARE_FINAL_OBJECT_INFO(LoadNode, PrimExprNode);
727 };
728
729 /*!
730 * \brief Managed reference to LoadNode
731 * \sa LoadNode
732 */
733 class Load : public PrimExpr {
734 public:
735 TVM_DLL Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate);
736 TVM_DEFINE_OBJECT_REF_METHODS(Load, PrimExpr, LoadNode);
737 };
738
739 /*!
740 * \brief Construct a vector with lanes elements
741 * where its i-th element equals base + i * stride.
742 * This is useful to construct a index for a continuous vector load.
743 *
744 * Examples:
745 * - ramp(0, 1, 3) = [0, 1, 2]
746 * - ramp(1, 2, 4) = [1, 3, 5, 7]
747 */
748 class RampNode : public PrimExprNode {
749 public:
750 /*! \brief The base value. */
751 PrimExpr base;
752 /*! \brief The stride of each step. */
753 PrimExpr stride;
754 /*! \brief Total number of lanes. */
755 int lanes;
756
VisitAttrs(AttrVisitor * v)757 void VisitAttrs(AttrVisitor* v) {
758 v->Visit("dtype", &dtype);
759 v->Visit("base", &base);
760 v->Visit("stride", &stride);
761 v->Visit("lanes", &lanes);
762 }
763
SEqualReduce(const RampNode * other,SEqualReducer equal)764 bool SEqualReduce(const RampNode* other, SEqualReducer equal) const {
765 return equal(dtype, other->dtype) && equal(base, other->base) && equal(stride, other->stride) &&
766 equal(lanes, other->lanes);
767 }
768
SHashReduce(SHashReducer hash_reduce)769 void SHashReduce(SHashReducer hash_reduce) const {
770 hash_reduce(dtype);
771 hash_reduce(base);
772 hash_reduce(stride);
773 hash_reduce(lanes);
774 }
775
776 static constexpr const char* _type_key = "tir.Ramp";
777 TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, PrimExprNode);
778 };
779
780 /*!
781 * \brief Managed reference to RampNode
782 * \sa RampNode
783 */
784 class Ramp : public PrimExpr {
785 public:
786 TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes);
787 TVM_DEFINE_OBJECT_REF_METHODS(Ramp, PrimExpr, RampNode);
788 };
789
790 /*! \brief Create a vector where all the elements are value. */
791 class BroadcastNode : public PrimExprNode {
792 public:
793 /*! \brief The base value. */
794 PrimExpr value;
795 /*! \brief The number of lanes. */
796 int lanes;
797
VisitAttrs(AttrVisitor * v)798 void VisitAttrs(AttrVisitor* v) {
799 v->Visit("dtype", &dtype);
800 v->Visit("value", &value);
801 v->Visit("lanes", &lanes);
802 }
803
SEqualReduce(const BroadcastNode * other,SEqualReducer equal)804 bool SEqualReduce(const BroadcastNode* other, SEqualReducer equal) const {
805 return equal(dtype, other->dtype) && equal(value, other->value) && equal(lanes, other->lanes);
806 }
807
SHashReduce(SHashReducer hash_reduce)808 void SHashReduce(SHashReducer hash_reduce) const {
809 hash_reduce(dtype);
810 hash_reduce(value);
811 hash_reduce(lanes);
812 }
813
814 static constexpr const char* _type_key = "tir.Broadcast";
815 TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, PrimExprNode);
816 };
817
818 /*!
819 * \brief Managed reference to BroadcastNode
820 * \sa BroadcastNode
821 */
822 class Broadcast : public PrimExpr {
823 public:
824 TVM_DLL Broadcast(PrimExpr value, int lanes);
825 TVM_DEFINE_OBJECT_REF_METHODS(Broadcast, PrimExpr, BroadcastNode);
826 };
827
828 /*!
829 * \brief Let binding. Bind var to value then evaluate body.
830 */
831 class LetNode : public PrimExprNode {
832 public:
833 /*! \brief The variable. */
834 Var var;
835 /*! \brief The value to be binded. */
836 PrimExpr value;
837 /*! \brief The result expression. */
838 PrimExpr body;
839
VisitAttrs(AttrVisitor * v)840 void VisitAttrs(AttrVisitor* v) {
841 v->Visit("dtype", &dtype);
842 v->Visit("var", &var);
843 v->Visit("value", &value);
844 v->Visit("body", &body);
845 }
846
SEqualReduce(const LetNode * other,SEqualReducer equal)847 bool SEqualReduce(const LetNode* other, SEqualReducer equal) const {
848 return equal(dtype, other->dtype) && equal.DefEqual(var, other->var) &&
849 equal(value, other->value) && equal(body, other->body);
850 }
851
SHashReduce(SHashReducer hash_reduce)852 void SHashReduce(SHashReducer hash_reduce) const {
853 hash_reduce(dtype);
854 hash_reduce.DefHash(var);
855 hash_reduce(value);
856 hash_reduce(body);
857 }
858
859 static constexpr const char* _type_key = "tir.Let";
860 TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, PrimExprNode);
861 };
862
863 /*!
864 * \brief Managed reference to LetNode
865 * \sa LetNode
866 */
867 class Let : public PrimExpr {
868 public:
869 TVM_DLL Let(Var var, PrimExpr value, PrimExpr body);
870 TVM_DEFINE_OBJECT_REF_METHODS(Let, PrimExpr, LetNode);
871 };
872
873 /*!
874 * \brief Call node.
875 */
876 class CallNode : public PrimExprNode {
877 public:
878 /*!
879 * \brief The operator(function) being invoked
880 *
881 * - It can be tvm::Op which corresponds to the primitive operators(intrinsics).
882 * - It can also be another function in the IRModule (GlobalVar).
883 */
884 RelayExpr op;
885
886 /*! \brief The arguments. */
887 Array<PrimExpr> args;
VisitAttrs(AttrVisitor * v)888 void VisitAttrs(AttrVisitor* v) {
889 v->Visit("dtype", &dtype);
890 v->Visit("op", &op);
891 v->Visit("args", &args);
892 }
893
SEqualReduce(const CallNode * other,SEqualReducer equal)894 bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
895 return equal(dtype, other->dtype) && equal(op, other->op) && equal(args, other->args);
896 }
897
SHashReduce(SHashReducer hash_reduce)898 void SHashReduce(SHashReducer hash_reduce) const {
899 hash_reduce(dtype);
900 hash_reduce(op);
901 hash_reduce(args);
902 }
903
904 static constexpr const char* _type_key = "tir.Call";
905 TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode);
906 };
907
908 /*!
909 * \brief Managed reference to CallNode
910 * \sa CallNode
911 */
912 class Call : public PrimExpr {
913 public:
914 TVM_DLL Call(DataType dtype, RelayExpr op, Array<PrimExpr> args);
915 TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode);
916 };
917
918 /*!
919 * \brief Shuffle instruction.
920 * vec = concat(vectors)
921 * result = (vec[indices[0]], vec[indices[1]] ...)
922 */
923 class ShuffleNode : public PrimExprNode {
924 public:
925 /*! \brief the input vectors. */
926 Array<PrimExpr> vectors;
927 /*! \brief The indices of each element. */
928 Array<PrimExpr> indices;
929
VisitAttrs(AttrVisitor * v)930 void VisitAttrs(AttrVisitor* v) {
931 v->Visit("vectors", &vectors);
932 v->Visit("indices", &indices);
933 }
934
SEqualReduce(const ShuffleNode * other,SEqualReducer equal)935 bool SEqualReduce(const ShuffleNode* other, SEqualReducer equal) const {
936 return equal(dtype, other->dtype) && equal(vectors, other->vectors) &&
937 equal(indices, other->indices);
938 }
939
SHashReduce(SHashReducer hash_reduce)940 void SHashReduce(SHashReducer hash_reduce) const {
941 hash_reduce(dtype);
942 hash_reduce(vectors);
943 hash_reduce(indices);
944 }
945
946 static constexpr const char* _type_key = "tir.Shuffle";
947 TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, PrimExprNode);
948 };
949
950 /*!
951 * \brief Managed reference to ShuffleNode
952 * \sa ShuffleNode
953 */
954 class Shuffle : public PrimExpr {
955 public:
956 TVM_DLL Shuffle(Array<PrimExpr> vectors, Array<PrimExpr> indices);
957 TVM_DLL static PrimExpr Concat(Array<PrimExpr> vectors);
958 TVM_DLL static PrimExpr ExtractElement(PrimExpr vector, int index);
959
960 TVM_DEFINE_OBJECT_REF_METHODS(Shuffle, PrimExpr, ShuffleNode);
961 };
962
963 // Reduce operator
964 /*!
965 * \brief A commutative reducer node to represent a commutative
966 * binary operator with identity element
967 */
968 class CommReducerNode : public Object {
969 public:
970 /*! \brief The left argument of reducer */
971 Array<Var> lhs;
972 /*! \brief The right argument of reducer */
973 Array<Var> rhs;
974 /*! \brief The result of reducer */
975 Array<PrimExpr> result;
976 /*!
977 * \brief The identity element of reducer, which leaves other
978 * elements unchanged when combined with it, with respect to
979 * the binary operation of this reducer uses.
980 */
981 Array<PrimExpr> identity_element;
982 /*! \brief Function call operator to combine a and b */
983 Array<PrimExpr> operator()(Array<PrimExpr> a, Array<PrimExpr> b) const;
984
VisitAttrs(AttrVisitor * v)985 void VisitAttrs(AttrVisitor* v) {
986 v->Visit("lhs", &lhs);
987 v->Visit("rhs", &rhs);
988 v->Visit("result", &result);
989 v->Visit("identity_element", &identity_element);
990 }
991
SEqualReduce(const CommReducerNode * other,SEqualReducer equal)992 bool SEqualReduce(const CommReducerNode* other, SEqualReducer equal) const {
993 return equal.DefEqual(lhs, other->lhs) && equal.DefEqual(rhs, other->rhs) &&
994 equal(result, other->result) && equal(identity_element, other->identity_element);
995 }
996
SHashReduce(SHashReducer hash_reduce)997 void SHashReduce(SHashReducer hash_reduce) const {
998 hash_reduce.DefHash(lhs);
999 hash_reduce.DefHash(rhs);
1000 hash_reduce(result);
1001 hash_reduce(identity_element);
1002 }
1003
1004 static constexpr const char* _type_key = "tir.CommReducer";
1005 static constexpr const bool _type_has_method_sequal_reduce = true;
1006 static constexpr const bool _type_has_method_shash_reduce = true;
1007 TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object);
1008 };
1009
1010 /*!
1011 * \brief Managed reference to CommReducerNode
1012 * \sa CommReducerNode
1013 */
1014 class CommReducer : public ObjectRef {
1015 public:
1016 TVM_DLL CommReducer(Array<Var> lhs, Array<Var> rhs, Array<PrimExpr> result,
1017 Array<PrimExpr> identity_element);
1018
1019 TVM_DEFINE_OBJECT_REF_METHODS(CommReducer, ObjectRef, CommReducerNode);
1020 };
1021
1022 /*! \brief Reduction operator operator */
1023 class ReduceNode : public PrimExprNode {
1024 public:
1025 /*! \brief The commutative combiner */
1026 CommReducer combiner;
1027 /*! \brief The source operand */
1028 Array<PrimExpr> source;
1029 /*! \brief The init operand */
1030 Array<PrimExpr> init;
1031 /*! \brief The reduction axis */
1032 Array<IterVar> axis;
1033 /*!
1034 * \brief Predicate on the reduction
1035 * Only add the body to reduction if condition is true.
1036 */
1037 PrimExpr condition;
1038 /*! \brief the index of this reduce node */
1039 int value_index;
1040
VisitAttrs(AttrVisitor * v)1041 void VisitAttrs(AttrVisitor* v) {
1042 v->Visit("dtype", &dtype);
1043 v->Visit("combiner", &combiner);
1044 v->Visit("source", &source);
1045 v->Visit("init", &init);
1046 v->Visit("axis", &axis);
1047 v->Visit("condition", &condition);
1048 v->Visit("value_index", &value_index);
1049 }
1050
SEqualReduce(const ReduceNode * other,SEqualReducer equal)1051 bool SEqualReduce(const ReduceNode* other, SEqualReducer equal) const {
1052 // check axis first so IterVars can define the necessary variables.
1053 return equal(dtype, other->dtype) && equal(axis, other->axis) &&
1054 equal(combiner, other->combiner) && equal(source, other->source) &&
1055 equal(init, other->init) && equal(condition, other->condition) &&
1056 equal(value_index, other->value_index);
1057 }
1058
SHashReduce(SHashReducer hash_reduce)1059 void SHashReduce(SHashReducer hash_reduce) const {
1060 hash_reduce(dtype);
1061 hash_reduce(axis);
1062 hash_reduce(combiner);
1063 hash_reduce(source);
1064 hash_reduce(init);
1065 hash_reduce(condition);
1066 hash_reduce(value_index);
1067 }
1068
1069 static constexpr const char* _type_key = "tir.Reduce";
1070 TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode);
1071 };
1072
1073 /*!
1074 * \brief Managed reference to ReduceNode
1075 * \sa ReduceNode
1076 */
1077 class Reduce : public PrimExpr {
1078 public:
1079 TVM_DLL Reduce(CommReducer combiner, Array<PrimExpr> src, Array<IterVar> rdom, PrimExpr condition,
1080 int value_index, Array<PrimExpr> init);
1081
1082 TVM_DEFINE_OBJECT_REF_METHODS(Reduce, PrimExpr, ReduceNode);
1083 };
1084
1085 /*! \brief Any shape. */
1086 class AnyNode : public PrimExprNode {
1087 public:
VisitAttrs(AttrVisitor * v)1088 void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); }
1089
SEqualReduce(const AnyNode * other,SEqualReducer equal)1090 bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const {
1091 return equal(dtype, other->dtype);
1092 }
1093
SHashReduce(SHashReducer hash_reduce)1094 void SHashReduce(SHashReducer hash_reduce) const {}
1095
1096 /*! \brief Convert to var. */
ToVar()1097 Var ToVar() const { return Var("any_dim", DataType::Int(32)); }
1098
1099 static constexpr const char* _type_key = "tir.Any";
1100 TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, PrimExprNode);
1101 };
1102
1103 /*!
1104 * \brief Managed reference to AnyNode
1105 * \sa AnyNode
1106 */
1107 class Any : public PrimExpr {
1108 public:
1109 TVM_DLL Any();
1110
1111 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Any, PrimExpr, AnyNode);
1112 };
1113
1114 /*
1115 * \brief Template function to convert Map to unordered_map
1116 * Sometimes useful for API gluing when internal uses unordered_map
1117 * \param dmap The container map
1118 * \return The corresponding unordered_map.
1119 * \tparam K the key of the Map.
1120 * \tparam V the value of the Map.
1121 */
1122 template <typename K, typename V>
as_unordered_map(const Map<K,V> & dmap)1123 inline std::unordered_map<K, V> as_unordered_map(const Map<K, V>& dmap) {
1124 std::unordered_map<K, V> ret;
1125 for (auto kv : dmap) {
1126 ret[kv.first] = kv.second;
1127 }
1128 return ret;
1129 }
1130 } // namespace tir
1131 } // namespace tvm
1132
1133 namespace std {
1134 template <>
1135 struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectPtrHash {};
1136 } // namespace std
1137 #endif // TVM_TIR_EXPR_H_
1138