1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file tvm/tir/op.h
22  * \brief Common operators defined for Expr.
23  *
 * \note Most of the operators defined here perform simple constant folding
 *   when the type is int32 or int64, to simplify index expressions.
26  */
27 // Acknowledgement: Most operator APIs originate from Halide.
28 #ifndef TVM_TIR_OP_H_
29 #define TVM_TIR_OP_H_
30 
31 #include <tvm/ir/op.h>
32 #include <tvm/ir/type.h>
33 #include <tvm/tir/expr.h>
34 #include <tvm/tir/stmt.h>
35 
36 #include <algorithm>
37 #include <limits>
38 #include <type_traits>
39 
40 namespace tvm {
41 
// Most common operators can be overloaded by argument type (PrimExpr),
// so we put them under the root namespace.
// It is also necessary to overload operators for PrimExpr.
//
// We put more developer-oriented APIs -- make_const and is_const -- under tir,
// as they are more specific to the tir namespace.
48 
49 /*!
50  * \brief Get the type of the expression under the unified type system.
51  *
52  * This function could return a more refined type than
53  * the runtime type provided by expr->dtype
54  *
55  * \param expr The input parameter.
56  * \return The result type.
57  *
58  * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType.
59  */
60 TVM_DLL Type GetType(const PrimExpr& expr);
61 
62 /*!
63  * \brief Get the implied DataType for storing values with type during runtime.
64  *
65  * \param type The input type.
66  * \return The result runtime::DataType.
67  *
68  * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType.
69  */
70 TVM_DLL runtime::DataType GetRuntimeDataType(const Type& type);
71 
72 /*!
73  * Query the maximum possible value of dtype.
74  * \param dtype The data type.
75  * \return the maximum possible value in this format.
76  */
77 TVM_DLL PrimExpr max_value(const DataType& dtype);
78 
79 /*!
80  * Query the minimum possible value of dtype.
81  * \param dtype The data type.
82  * \return the minimum possible value in this format.
83  */
84 TVM_DLL PrimExpr min_value(const DataType& dtype);
85 
86 /*!
87  * Get the value of infinity.
88  * \param dtype The data type.
89  * \return the infinity value in this format.
90  */
91 TVM_DLL PrimExpr infinity(const DataType& dtype);
92 
93 /*!
94  * \brief cast value to type.
95  *
96  * \param t the target type.
97  * \param value The value
98  * \return The result expression.
99  * \note This function may return value if the type is the same.
100  */
101 TVM_DLL PrimExpr cast(const DataType& t, PrimExpr value);
102 /*!
103  * \brief perform reinterpret cast value to type.
104  *
105  * \param t the target type.
106  * \param value The value
107  * \return The result expression.
108  * \note This function may return value if the type is the same.
109  */
110 TVM_DLL PrimExpr reinterpret(const DataType& t, PrimExpr value);
111 /*!
112  * \brief add operator
113  *
114  * \param a left operand
115  * \param b right operand
116  * \return The result expression.
117  * \note this function does eager constant folding for
118  *       index types(int32, int64) when possible.
119  */
120 TVM_DLL PrimExpr operator+(PrimExpr a, PrimExpr b);
121 /*!
122  * \brief subtraction operator
123  *
124  * \param a left operand
125  * \param b right operand
126  * \return The result expression.
127  * \note this function does eager constant folding for
128  *       index types(int32, int64) when possible.
129  */
130 TVM_DLL PrimExpr operator-(PrimExpr a, PrimExpr b);
131 /*!
132  * \brief negation.
133  *
134  * \param a input.
135  * \return The result expression.
136  * \note this function does eager constant folding for
137  *       index types(int32, int64) when possible.
138  */
139 TVM_DLL PrimExpr operator-(PrimExpr a);
140 /*!
141  * \brief multiplication operator
142  *
143  * \param a left operand
144  * \param b right operand
145  * \return The result expression.
146  * \note this function does eager constant folding for
147  *       index types(int32, int64) when possible.
148  */
149 TVM_DLL PrimExpr operator*(PrimExpr a, PrimExpr b);
150 /*!
151  * \brief division operator
152  *
153  * \param a left operand
154  * \param b right operand
155  * \return The result expression.
156  * \note this function does eager constant folding for
157  *       index types(int32, int64) when possible.
158  */
159 TVM_DLL PrimExpr operator/(PrimExpr a, PrimExpr b);
160 /*!
161  * \brief left shift operator
162  *
163  * \param a left operand
164  * \param b right operand
165  * \return The result expression.
166  * \note this function does eager constant folding for
167  *       index types(int32, int64) when possible.
168  */
169 TVM_DLL PrimExpr operator<<(PrimExpr a, PrimExpr b);
170 /*!
171  * \brief right shift operator
172  *
173  * \param a left operand
174  * \param b right operand
175  * \return The result expression.
176  * \note this function does eager constant folding for
177  *       index types(int32, int64) when possible.
178  */
179 TVM_DLL PrimExpr operator>>(PrimExpr a, PrimExpr b);
180 /*!
181  * \brief greater
182  *
183  * \param a left operand
184  * \param b right operand
185  * \return The result expression.
186  * \note this function does eager constant folding for
187  *       index types(int32, int64) when possible.
188  */
189 TVM_DLL PrimExpr operator>(PrimExpr a, PrimExpr b);
190 /*!
191  * \brief greater_equal
192  *
193  * \param a left operand
194  * \param b right operand
195  * \return The result expression.
196  * \note this function does eager constant folding for
197  *       index types(int32, int64) when possible.
198  */
199 TVM_DLL PrimExpr operator>=(PrimExpr a, PrimExpr b);
200 /*!
201  * \brief less
202  *
203  * \param a left operand
204  * \param b right operand
205  * \return The result expression.
206  * \note this function does eager constant folding for
207  *       index types(int32, int64) when possible.
208  */
209 TVM_DLL PrimExpr operator<(PrimExpr a, PrimExpr b);
210 /*!
211  * \brief less_equal
212  *
213  * \param a left operand
214  * \param b right operand
215  * \return The result expression.
216  * \note this function does eager constant folding for
217  *       index types(int32, int64) when possible.
218  */
219 TVM_DLL PrimExpr operator<=(PrimExpr a, PrimExpr b);
220 /*!
221  * \brief equal
222  *
223  * \param a left operand
224  * \param b right operand
225  * \return The result expression.
226  * \note this function does eager constant folding for
227  *       index types(int32, int64) when possible.
228  */
229 TVM_DLL PrimExpr operator==(PrimExpr a, PrimExpr b);
230 /*!
231  * \brief not_equal
232  *
233  * \param a left operand
234  * \param b right operand
235  * \return The result expression.
236  * \note this function does eager constant folding for
237  *       index types(int32, int64) when possible.
238  */
239 TVM_DLL PrimExpr operator!=(PrimExpr a, PrimExpr b);
240 /*!
241  * \brief and
242  *
243  * \param a left operand
244  * \param b right operand
245  * \return The result expression.
246  * \note This operator does eager constant folding.
247  */
248 TVM_DLL PrimExpr operator&&(PrimExpr a, PrimExpr b);
249 /*!
250  * \brief or
251  *
252  * \param a left operand
253  * \param b right operand
254  * \return The result expression.
255  * \note This operator does eager constant folding.
256  */
257 TVM_DLL PrimExpr operator||(PrimExpr a, PrimExpr b);
258 /*!
259  * \brief not
260  *
261  * \param a left operand
262  * \return The result expression.
263  * \note This operator does eager constant folding.
264  */
265 TVM_DLL PrimExpr operator!(PrimExpr a);
266 /*!
267  * \brief compute division in C semantics.
268  *
269  * a / b as in C/C++.
270  *
271  * When operands are integers, it directly corresponds to truncdiv.
272  *
273  * \param a left operand
274  * \param b right operand
275  * \return The result expression.
276  * \note this function does eager constant folding for
277  *       index types(int32, int64) when possible.
278  */
279 TVM_DLL PrimExpr div(PrimExpr a, PrimExpr b);
280 /*!
281  * \brief compute trunc(a / b)
282  *
283  * This is the default integer division behavior in C.
284  *
285  * \param a left operand
286  * \param b right operand
287  * \return The result expression.
288  * \note this function does eager constant folding for
289  *       index types(int32, int64) when possible.
290  */
291 TVM_DLL PrimExpr truncdiv(PrimExpr a, PrimExpr b);
292 /*!
293  * \brief compute the remainder of truncdiv
294  *
295  * This is the default integer division behavior in C.
296  *
297  * \param a left operand
298  * \param b right operand
299  * \return The result expression.
300  * \note this function does eager constant folding for
301  *       index types(int32, int64) when possible.
302  */
303 TVM_DLL PrimExpr truncmod(PrimExpr a, PrimExpr b);
304 /*!
305  * \brief compute floor(a / b) where a and b are non-negative.
306  *
307  * Use this function for index split calculation.
308  *
309  * This function might take advantage of the fact
310  * that a and b are non-negative.
311  *
312  * \param a left operand
313  * \param b right operand
314  * \return The result expression.
315  * \note this function does eager constant folding for
316  *       index types(int32, int64) when possible.
317  */
318 TVM_DLL PrimExpr indexdiv(PrimExpr a, PrimExpr b);
319 /*!
 * \brief compute the remainder of floor(a / b) where a and b are non-negative.
321  *
322  * Use this function for index split calculation.
323  * This function might take advantage of the fact
324  * that a and b are non-negative.
325  *
326  * \param a left operand
327  * \param b right operand
328  * \return The result expression.
329  * \note this function does eager constant folding for
330  *       index types(int32, int64) when possible.
331  */
332 TVM_DLL PrimExpr indexmod(PrimExpr a, PrimExpr b);
333 /*!
334  * \brief compute floor(a / b)
335  *
336  * \param a left operand
337  * \param b right operand
338  * \return The result expression.
339  * \note this function does eager constant folding for
340  *       index types(int32, int64) when possible.
341  */
342 TVM_DLL PrimExpr floordiv(PrimExpr a, PrimExpr b);
343 /*!
344  * \brief compute the remainder of floordiv
345  *
346  * \param a left operand
347  * \param b right operand
348  * \return The result expression.
349  * \note this function does eager constant folding for
350  *       index types(int32, int64) when possible.
351  */
352 TVM_DLL PrimExpr floormod(PrimExpr a, PrimExpr b);
353 /*!
354  * \brief take maximum of two values
355  *
356  * \param a left operand
357  * \param b right operand
358  * \return The result expression.
359  * \note this function does eager constant folding for
360  *       index types(int32, int64) when possible.
361  */
362 TVM_DLL PrimExpr max(PrimExpr a, PrimExpr b);
363 /*!
364  * \brief take minimum of two values
365  *
366  * \param a left operand
367  * \param b right operand
368  * \return The result expression.
369  * \note this function does eager constant folding for
370  *       index types(int32, int64) when possible.
371  */
372 TVM_DLL PrimExpr min(PrimExpr a, PrimExpr b);
373 /*!
374  * \brief take bitwise and of two values
375  *
376  * \param a left operand
377  * \param b right operand
378  * \return The result expression.
379  * \note this function does eager constant folding for
380  *       index types(int32, int64) when possible.
381  */
382 TVM_DLL PrimExpr operator&(PrimExpr a, PrimExpr b);
383 /*!
384  * \brief take bitwise or of two values
385  *
386  * \param a left operand
387  * \param b right operand
388  * \return The result expression.
389  * \note this function does eager constant folding for
390  *       index types(int32, int64) when possible.
391  */
392 TVM_DLL PrimExpr operator|(PrimExpr a, PrimExpr b);
393 /*!
394  * \brief take bitwise xor of two values
395  *
396  * \param a left operand
397  * \param b right operand
398  * \return The result expression.
399  * \note this function does eager constant folding for
400  *       index types(int32, int64) when possible.
401  */
402 TVM_DLL PrimExpr operator^(PrimExpr a, PrimExpr b);
403 /*!
 * \brief take bitwise negation of a value
405  *
406  * \param a the input expression.
407  * \return The result expression.
408  * \note this function does eager constant folding for
409  *       index types(int32, int64) when possible.
410  */
411 TVM_DLL PrimExpr operator~(PrimExpr a);
412 /*!
413  * \brief Conditional expression.
414  *
415  * \param cond The condition
416  * \param true_value The value when results are true.
417  * \param false_value The value when results are false.
418  * \return The result expression.
419  * \note this function does eager constant folding for
420  *       index types(int32, int64) when possible.
421  */
422 TVM_DLL PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value);
423 /*!
424  * \brief Mark condition as likely.
425  * \param cond The condition
426  * \return The marked expression.
427  */
428 TVM_DLL PrimExpr likely(PrimExpr cond);
429 /*!
 * \brief Calculate power(x, y)
 * \param x The left operand.
 * \param y The right operand.
 * \return The result expression.
 */
434 TVM_DLL PrimExpr pow(PrimExpr x, PrimExpr y);
435 /*!
436  * \brief Calculate absolute value of x.
437  * \param x The input data
438  *
 * \return The absolute value of input data x
440  */
441 TVM_DLL PrimExpr abs(PrimExpr x);
442 /*!
443  * \brief Check if x is NaN.
444  * \param x The input data
445  * \return The result expression.
446  */
447 TVM_DLL PrimExpr isnan(PrimExpr x);
448 
449 /*!
450  * \brief Check if x is finite.
451  * \param x The input data
452  * \return The result expression.
453  */
454 TVM_DLL PrimExpr isfinite(PrimExpr x);
455 
456 /*!
457  * \brief Check if x is infinite.
458  * \param x The input data
459  * \return The result expression.
460  */
461 TVM_DLL PrimExpr isinf(PrimExpr x);
462 
463 /*!
 * \brief sum of source expression over axis
465  * \param source The source expression.
466  * \param axis List of iteration variables that will be used for reduction.
467  * \param init The value with which to initialize the output.
468  * \return The result.
469  */
470 TVM_DLL PrimExpr sum(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});
471 
472 /*!
 * \brief logical And of source expression over axis
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 * \param init The value with which to initialize the output.
 * \return The result.
 */
478 TVM_DLL PrimExpr all(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});
479 
480 /*!
 * \brief logical Or of source expression over axis
482  * \param source The source expression.
483  * \param axis List of iteration variables that will be used for reduction.
484  * \param init The value with which to initialize the output.
485  * \return The result.
486  */
487 TVM_DLL PrimExpr any(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});
488 
489 /*!
 * \brief max of source expression over axis
491  * \param source The source expression.
492  * \param axis List of iteration variables that will be used for reduction.
493  * \param init The value with which to initialize the output.
494  * \return The result.
495  */
496 TVM_DLL PrimExpr max(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});
497 
498 /*!
 * \brief min of source expression over axis
500  * \param source The source expression.
501  * \param axis List of iteration variables that will be used for reduction.
502  * \param init The value with which to initialize the output.
503  * \return The result.
504  */
505 TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});
506 
507 /*!
 * \brief product of source expression over axis
509  * \param source The source expression.
510  * \param axis List of iteration variables that will be used for reduction.
511  * \param init The value with which to initialize the output.
512  * \return The result.
513  */
514 TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});
515 
516 /*!
517  * \brief Calculate floor(x)
518  * \param x The input expression.
519  * \return The result expression.
520  */
521 TVM_DLL PrimExpr floor(PrimExpr x);
522 
523 /*!
524  * \brief Calculate ceil(x)
525  * \param x The input expression.
526  * \return The result expression.
527  */
528 TVM_DLL PrimExpr ceil(PrimExpr x);
529 
530 /*!
531  * \brief Calculate round(x)
532  * \param x The input expression.
533  * \return The result expression.
534  */
535 TVM_DLL PrimExpr round(PrimExpr x);
536 
537 /*!
 * \brief Calculates std::nearbyint(x)
 *
 * This is a faster alternative to round.
 *
 * \param x The input expression.
 * \return The result expression.
542  */
543 TVM_DLL PrimExpr nearbyint(PrimExpr x);
544 
545 /*!
546  * \brief Calculate trunc(x)
547  * \param x The input expression.
548  * \return The result expression.
549  */
550 TVM_DLL PrimExpr trunc(PrimExpr x);
551 
552 /*!
 * \brief Construct a large uint constant from its low 32 bits and high 32 bits.
554  * \param dtype The final data type.
555  * \param low The lower 32 bits.
556  * \param high The higher 32 bits.
557  * \return The constructed expression.
558  */
559 TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);
560 
561 /*!
562  * \brief Execute a multiplication between two Q-numbers x and y
563  * followed by a right shift s. The mathematical expression is:
564  *
565  *    out = round(x*y*2^-s)
566  *
567  * Please note that the two Q-numbers x and y are supposed to have
568  * the same number of fractional bits q.
569  *
570  * More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format)
571  *
572  * The rounding rule is to the nearest value, rounding half up
573  * (i.e., round(x.1) = x and round (x.5) = x+1)
574  * \param x first Q-number
575  * \param y second Q-number
576  * \param q number of fractional bits in x and y. Needs to be > 0
577  * \param s integer right shift
578  * \return The constructed expression.
579  */
580 TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s);
581 
582 // Intrinsic operators
583 #define TVM_DECLARE_INTRIN_UNARY(OpName)           \
584   inline PrimExpr OpName(PrimExpr x) {             \
585     static const Op& op = Op::Get("tir." #OpName); \
586     return tir::Call(x.dtype(), op, {x});          \
587   }
588 
589 TVM_DECLARE_INTRIN_UNARY(exp);
590 TVM_DECLARE_INTRIN_UNARY(exp2);
591 TVM_DECLARE_INTRIN_UNARY(exp10);
592 TVM_DECLARE_INTRIN_UNARY(erf);
593 TVM_DECLARE_INTRIN_UNARY(tanh);
594 TVM_DECLARE_INTRIN_UNARY(sigmoid);
595 TVM_DECLARE_INTRIN_UNARY(sqrt);
596 TVM_DECLARE_INTRIN_UNARY(rsqrt);
597 TVM_DECLARE_INTRIN_UNARY(log);
598 TVM_DECLARE_INTRIN_UNARY(log2);
599 TVM_DECLARE_INTRIN_UNARY(log10);
600 TVM_DECLARE_INTRIN_UNARY(popcount);
601 TVM_DECLARE_INTRIN_UNARY(tan);
602 TVM_DECLARE_INTRIN_UNARY(cos);
603 TVM_DECLARE_INTRIN_UNARY(cosh);
604 TVM_DECLARE_INTRIN_UNARY(sin);
605 TVM_DECLARE_INTRIN_UNARY(sinh);
606 TVM_DECLARE_INTRIN_UNARY(asin);
607 TVM_DECLARE_INTRIN_UNARY(acos);
608 TVM_DECLARE_INTRIN_UNARY(atan);
609 TVM_DECLARE_INTRIN_UNARY(acosh);
610 TVM_DECLARE_INTRIN_UNARY(asinh);
611 TVM_DECLARE_INTRIN_UNARY(atanh);
612 
613 #define TVM_DECLARE_INTRIN_BINARY(OpName)          \
614   inline PrimExpr OpName(PrimExpr x, PrimExpr y) { \
615     static const Op& op = Op::Get("tir." #OpName); \
616     return tir::Call(x.dtype(), op, {x, y});       \
617   }
618 
619 TVM_DECLARE_INTRIN_BINARY(atan2);
620 TVM_DECLARE_INTRIN_BINARY(nextafter);
621 TVM_DECLARE_INTRIN_BINARY(copysign);
622 TVM_DECLARE_INTRIN_BINARY(hypot);
623 TVM_DECLARE_INTRIN_BINARY(ldexp);
624 
625 namespace tir {
626 
627 /*!
628  * \brief Check if type is a pointer to a runtime element type.
629  * \param type The type to be checked.
630  * \param element_type The corresponding element type.
631  * \return The check results
632  */
IsPointerType(const Type & type,const DataType & element_type)633 inline bool IsPointerType(const Type& type, const DataType& element_type) {
634   if (!type.defined()) return false;
635   if (const auto* ptr_type = type.as<PointerTypeNode>()) {
636     if (const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>()) {
637       return prim_type->dtype == element_type;
638     }
639   }
640   return false;
641 }
642 
643 /*!
644  * \brief Make a const value with certain data type.
645  * \param t The target type.
646  * \param value The input value
647  * \return the result expression.
648  * \tparam ValueType The constant value type
649  */
650 template <typename ValueType,
651           typename = typename std::enable_if<std::is_pod<ValueType>::value>::type>
652 inline PrimExpr make_const(DataType t, ValueType value);
653 /*!
654  * \brief Make a const zero expr.
655  * \param t The target type.
656  * \return the result expression.
657  */
658 inline PrimExpr make_zero(DataType t);
659 /*!
660  * \brief Make a constant true expression.
661  * \param lanes The number of lanes in the bool
662  * \return The result expression.
663  */
664 inline PrimExpr const_true(int lanes = 1) { return make_const(DataType::UInt(1, lanes), 1); }
665 /*!
666  * \brief Make a constant false expression.
667  * \param lanes The number of lanes in the bool
668  * \return The result expression.
669  */
670 inline PrimExpr const_false(int lanes = 1) { return make_const(DataType::UInt(1, lanes), 0); }
671 /*!
672  * \brief Get x as constant int expression.
673  * \param x The expression
674  * \return the address to the int expression,
675  *         return nullptr, if x is not IntImm.
676  */
as_const_int(const PrimExpr & x)677 inline const int64_t* as_const_int(const PrimExpr& x) {
678   if (!x.defined()) return nullptr;
679   if (const tir::IntImmNode* op = x.as<tir::IntImmNode>()) {
680     return &(op->value);
681   } else {
682     return nullptr;
683   }
684 }
685 
686 /*!
687  * \brief Check whether x is a constant integer expression.
688  * \param x The input argument
689  * \param value the value to be compared against.
690  * \return whether x is constant expression.
691  */
692 inline bool is_const_int(const PrimExpr& x, int64_t value);
693 
694 /*!
695  * \brief Check whether stmt is nop.
696  * \param stmt The input statement
697  * \return whether stmt is nop
698  */
699 inline bool is_no_op(const tir::Stmt& stmt);
700 
701 /*!
702  * \brief Check whether x is a constant integer 1
703  * \param x The input argument.
704  * \note This only return true for integer types.
705  * \return whether x is constant 1
706  */
is_one(const PrimExpr & x)707 inline bool is_one(const PrimExpr& x) { return is_const_int(x, 1); }
708 
709 /*!
710  * \brief Check whether x is a constant integer 0
711  * \param x The input argument
712  * \return whether x is constant 0
713  * \note This only return true for integer types.
714  */
is_zero(const PrimExpr & x)715 inline bool is_zero(const PrimExpr& x) { return is_const_int(x, 0); }
716 
/*!
 * \brief Check whether x is an integer constant (an IntImm, or a Broadcast of an IntImm).
 * \param x The input argument.
 * \note This only returns true for integer types.
 * \return whether x is constant
 */
722 inline bool is_const_int(const PrimExpr& x);
723 
/*!
 * \brief Check whether x is an integer/float constant.
 * \param x The input argument.
 * \note This returns true for an IntImm or FloatImm, and for a Broadcast of either.
 * \return whether x is constant
 */
729 inline bool is_const_number(const PrimExpr& x);
730 
731 /*!
732  * \brief Left fold.
733  * \param freduce The reduction function.
734  * \param init_value The initial value.
735  * \param values The values to be folded.
736  * \return The result.
737  * \tparam FReduce The type of the reduction.
738  */
739 template <typename FReduce>
740 inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr>& values);
741 
742 /*!
743  * \brief Check whether x is a constant power of two
744  * If x is power of two, write the power to the shift.
745  *
746  * \param x The input expression.
747  * \param shift The output shift if x is power of two.
748  * \return whether x is constant power of two
749  */
750 TVM_DLL bool is_const_power_of_two_integer(const PrimExpr& x, int* shift);
751 
752 // Implementation details after this
is_const_int(const PrimExpr & x)753 inline bool is_const_int(const PrimExpr& x) {
754   if (x.as<tir::IntImmNode>()) {
755     return true;
756   } else if (const auto* op = x.as<tir::BroadcastNode>()) {
757     const PrimExpr& val = op->value;
758     if (val.as<tir::IntImmNode>()) {
759       return true;
760     }
761   }
762   return false;
763 }
764 
is_const_number(const PrimExpr & x)765 inline bool is_const_number(const PrimExpr& x) {
766   if (x.as<tir::IntImmNode>()) {
767     return true;
768   } else if (x.as<tir::FloatImmNode>()) {
769     return true;
770   } else if (const auto* op = x.as<tir::BroadcastNode>()) {
771     return (op->value->IsInstance<tir::IntImmNode>() || op->value->IsInstance<tir::FloatImmNode>());
772   }
773   return false;
774 }
775 
is_positive_const(const PrimExpr & a)776 inline bool is_positive_const(const PrimExpr& a) {
777   if (const tir::IntImmNode* op = a.as<tir::IntImmNode>()) {
778     return op->value > 0;
779   } else {
780     return false;
781   }
782 }
783 
is_negative_const(const PrimExpr & a)784 inline bool is_negative_const(const PrimExpr& a) {
785   if (const tir::IntImmNode* op = a.as<tir::IntImmNode>()) {
786     return op->value < 0;
787   } else {
788     return false;
789   }
790 }
791 
is_const_int(const PrimExpr & x,int64_t value)792 inline bool is_const_int(const PrimExpr& x, int64_t value) {
793   if (const auto* op = x.as<tir::IntImmNode>()) {
794     return op->value == value;
795   } else if (const auto* op = x.as<tir::BroadcastNode>()) {
796     const PrimExpr& val = op->value;
797     if (const auto* opv = val.as<tir::IntImmNode>()) {
798       return opv->value == value;
799     }
800   }
801   return false;
802 }
803 
is_no_op(const tir::Stmt & stmt)804 inline bool is_no_op(const tir::Stmt& stmt) {
805   if (!stmt.defined()) return true;
806   if (const auto* op = stmt.as<tir::EvaluateNode>()) {
807     return is_const_int(op->value);
808   }
809   if (const auto* op = stmt.as<tir::SeqStmtNode>()) {
810     return op->seq.size() == 0;
811   }
812   return false;
813 }
814 
815 template <typename ValueType>
MakeConstScalar(DataType t,ValueType value)816 inline PrimExpr MakeConstScalar(DataType t, ValueType value) {
817   if (t.is_int()) return IntImm(t, static_cast<int64_t>(value));
818   if (t.is_uint()) {
819     // Use IntImm if it is a small integer
820     uint64_t uval = static_cast<uint64_t>(value);
821     if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
822       return IntImm(t, static_cast<int64_t>(value));
823     } else {
824       uint64_t mask = (static_cast<uint64_t>(1) << 32U) - 1U;
825       uint64_t low = uval & mask;
826       uint64_t high = uval >> 32U;
827       return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high));
828     }
829   }
830   if (t.is_float() || t.is_bfloat16()) return FloatImm(t, static_cast<double>(value));
831   // For now, we store const scalar values of custom datatypes within doubles; later, during the
832   // datatypes lowering pass, we will lower the value to its true representation in the format
833   // specified by the datatype.
834   // TODO(gus) when do we need to start worrying about doubles not being precise enough?
835   if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(DataType::kCustomBegin)) {
836     return FloatImm(t, static_cast<double>(value));
837   }
838   LOG(FATAL) << "cannot make const for type " << t;
839   return PrimExpr();
840 }
841 
842 template <typename ValueType, typename>
make_const(DataType t,ValueType value)843 inline PrimExpr make_const(DataType t, ValueType value) {
844   if (t.lanes() == 1) {
845     return MakeConstScalar(t, value);
846   } else {
847     return tir::Broadcast(MakeConstScalar(t.element_of(), value), t.lanes());
848   }
849 }
850 
make_zero(DataType t)851 inline PrimExpr make_zero(DataType t) {
852   if (t.is_handle()) {
853     return reinterpret(t, make_const(DataType::UInt(64), 0));
854   }
855   return make_const(t, 0);
856 }
857 
858 template <typename FReduce>
foldl(FReduce freduce,PrimExpr init_value,const Array<PrimExpr> & values)859 inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr>& values) {
860   for (PrimExpr val : values) {
861     init_value = freduce(init_value, val);
862   }
863   return init_value;
864 }
865 
866 }  // namespace tir
867 
868 // additional const expression overloading
/*!
 * \brief Define compound assignment operator Name so that `a Name b` means `a = OpFunc(a, b)`.
 * \note The generated operator returns the updated value by value, not a reference to a.
 */
#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \
  inline PrimExpr Name(PrimExpr& a, PrimExpr b) {   \
    return a = OpFunc(a, b);                        \
  }
874 
/*!
 * \brief Define overloads of binary op Name that accept a C++ constant on either side.
 *
 * - float operands are wrapped with PrimExpr(float).
 * - int operands become a constant with the dtype of the other (PrimExpr) operand.
 * - double operands become a Float(64) constant. They are accepted on both sides
 *   so that e.g. `2.0 * x` is not ambiguous between the int and float overloads.
 */
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name)                                   \
  inline PrimExpr Name(const PrimExpr& a, float b) { return Name(a, PrimExpr(b)); } \
  inline PrimExpr Name(float a, const PrimExpr& b) { return Name(PrimExpr(a), b); } \
  inline PrimExpr Name(int a, const PrimExpr& b) {                                  \
    return Name(tir::make_const(b.dtype(), a), b);                                  \
  }                                                                                 \
  inline PrimExpr Name(const PrimExpr& a, int b) {                                  \
    return Name(a, tir::make_const(a.dtype(), b));                                  \
  }                                                                                 \
  inline PrimExpr Name(const PrimExpr& a, double b) {                               \
    return Name(a, tir::make_const(DataType::Float(64), b));                        \
  }                                                                                 \
  inline PrimExpr Name(double a, const PrimExpr& b) {                               \
    return Name(tir::make_const(DataType::Float(64), a), b);                        \
  }
887 
/*!
 * \brief Define overloads of logical op Name that accept a bool on either side.
 * The bool is wrapped with PrimExpr(...) and forwarded to the PrimExpr overload.
 */
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \
  inline PrimExpr Name(bool a, const PrimExpr& b) {    \
    return Name(PrimExpr(a), b);                       \
  }                                                    \
  inline PrimExpr Name(const PrimExpr& a, bool b) {    \
    return Name(a, PrimExpr(b));                       \
  }
891 
/*!
 * \brief Define overloads of integer op Name that accept an int on either side.
 * The int becomes a constant with the dtype of the other (PrimExpr) operand.
 */
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name)   \
  inline PrimExpr Name(int a, const PrimExpr& b) {   \
    return Name(tir::make_const(b.dtype(), a), b);   \
  }                                                  \
  inline PrimExpr Name(const PrimExpr& a, int b) {   \
    return Name(a, tir::make_const(a.dtype(), b));   \
  }
897 
898 TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+);
899 TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator-=, operator-);
900 TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator*=, operator*);
901 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator+);
902 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator-);
903 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator*);
904 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(max);
905 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(min);
906 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(div);
907 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>);  // NOLINT(*)
908 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>=);
909 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<);  // NOLINT(*)
910 TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<=);
911 // integer related ops
912 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(indexdiv);
913 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(indexmod);
914 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncdiv);
915 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncmod);
916 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floordiv);
917 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floormod);
918 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>);  // NOLINT(*)
919 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<);  // NOLINT(*)
920 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&);
921 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator|);
922 TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator^);
923 // logical ops
924 TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator&&);
925 TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator||);
926 
/*!
 * \brief Helper function to raise a compiler error about division ambiguity.
 * \param a The left operand of the ambiguous expression; only its type is inspected.
 * \note Instantiating this function with a class type TA (such as PrimExpr)
 *       always results in a compile-time error via static_assert.
 * \tparam TA The type of the left operand.
 */
template <typename TA>
inline void DivAmbiguityError(const TA& a) {
  (void)a;  // Only the type matters; avoid an unused-parameter warning.
  constexpr bool div_ambiguity = !std::is_class<TA>::value;
  static_assert(div_ambiguity,
                "TVM supports multiple types of integer divisions, "
                "please call div, indexdiv/indexmod, "
                "floordiv/floormod or truncdiv/truncmod directly "
                "to avoid ambiguity in the code. "
                "Checkout these functions in tvm/tir/op.h.");
}
942 
943 // The following code are not intended to be used in the codebase.
944 // Instead, they generate clear compiler errors that ask developers
945 // to use the specific division function.
946 // The second template argument is necessary to make sure the
947 // code compiles lazily by the compiler during invocation.
948 template <typename TB>
949 inline PrimExpr operator/(const PrimExpr& a, const TB& b) {
950   DivAmbiguityError(a);
951   return a;
952 }
953 
954 template <typename TB>
955 inline PrimExpr operator/=(const PrimExpr& a, const TB& b) {
956   DivAmbiguityError(a);
957   return a;
958 }
959 
960 template <typename TB>
961 inline PrimExpr operator%(const PrimExpr& a, const TB& b) {
962   DivAmbiguityError(a);
963   return a;
964 }
965 }  // namespace tvm
966 #endif  // TVM_TIR_OP_H_
967