1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file tvm/tir/op.h
22 * \brief Common operators defined for Expr.
23 *
 * \note Most of the operators defined here perform simple constant folding
 *       when the type is int32 or int64, to simplify index expressions.
26 */
27 // Acknowledgement: Most operator APIs originate from Halide.
28 #ifndef TVM_TIR_OP_H_
29 #define TVM_TIR_OP_H_
30
31 #include <tvm/ir/op.h>
32 #include <tvm/ir/type.h>
33 #include <tvm/tir/expr.h>
34 #include <tvm/tir/stmt.h>
35
36 #include <algorithm>
37 #include <limits>
38 #include <type_traits>
39
40 namespace tvm {
41
// Most common operators can be overloaded by argument type (PrimExpr),
// so we put them under the root namespace.
// Overloading the operators for PrimExpr is also necessary.
//
// We put more developer-oriented APIs -- make_const and the is_const_* helpers -- under tir,
// as they are more specific to the tir namespace.
48
/*!
 * \brief Get the type of the expression under the unified type system.
 *
 * This function could return a more refined type than
 * the runtime type provided by expr->dtype.
 *
 * \param expr The input expression.
 * \return The result type.
 *
 * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType.
 */
TVM_DLL Type GetType(const PrimExpr& expr);

/*!
 * \brief Get the implied DataType for storing values with type during runtime.
 *
 * \param type The input type.
 * \return The result runtime::DataType.
 *
 * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType.
 */
TVM_DLL runtime::DataType GetRuntimeDataType(const Type& type);

/*!
 * \brief Query the maximum possible value of dtype.
 * \param dtype The data type.
 * \return The maximum possible value in this format.
 */
TVM_DLL PrimExpr max_value(const DataType& dtype);

/*!
 * \brief Query the minimum possible value of dtype.
 * \param dtype The data type.
 * \return The minimum possible value in this format.
 */
TVM_DLL PrimExpr min_value(const DataType& dtype);

/*!
 * \brief Get the value of infinity.
 * \param dtype The data type.
 * \return The infinity value in this format.
 */
TVM_DLL PrimExpr infinity(const DataType& dtype);
92
/*!
 * \brief Cast value to type t.
 *
 * \param t The target type.
 * \param value The value to convert.
 * \return The result expression.
 * \note This function may return value unchanged if the type is the same.
 */
TVM_DLL PrimExpr cast(const DataType& t, PrimExpr value);
/*!
 * \brief Perform a reinterpret (bit-level) cast of value to type t.
 *
 * \param t The target type.
 * \param value The value to reinterpret.
 * \return The result expression.
 * \note This function may return value unchanged if the type is the same.
 */
TVM_DLL PrimExpr reinterpret(const DataType& t, PrimExpr value);
/*!
 * \brief Addition operator.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator+(PrimExpr a, PrimExpr b);
/*!
 * \brief Subtraction operator.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator-(PrimExpr a, PrimExpr b);
/*!
 * \brief Unary negation.
 *
 * \param a The input operand.
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator-(PrimExpr a);
/*!
 * \brief Multiplication operator.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator*(PrimExpr a, PrimExpr b);
/*!
 * \brief Division operator.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator/(PrimExpr a, PrimExpr b);
/*!
 * \brief Left shift operator.
 *
 * \param a left operand
 * \param b right operand (shift amount)
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator<<(PrimExpr a, PrimExpr b);
/*!
 * \brief Right shift operator.
 *
 * \param a left operand
 * \param b right operand (shift amount)
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator>>(PrimExpr a, PrimExpr b);
/*!
 * \brief Greater-than comparison.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator>(PrimExpr a, PrimExpr b);
/*!
 * \brief Greater-than-or-equal comparison.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator>=(PrimExpr a, PrimExpr b);
/*!
 * \brief Less-than comparison.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator<(PrimExpr a, PrimExpr b);
/*!
 * \brief Less-than-or-equal comparison.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator<=(PrimExpr a, PrimExpr b);
/*!
 * \brief Equality comparison.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator==(PrimExpr a, PrimExpr b);
/*!
 * \brief Inequality comparison.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator!=(PrimExpr a, PrimExpr b);
/*!
 * \brief Logical and.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This operator does eager constant folding.
 */
TVM_DLL PrimExpr operator&&(PrimExpr a, PrimExpr b);
/*!
 * \brief Logical or.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This operator does eager constant folding.
 */
TVM_DLL PrimExpr operator||(PrimExpr a, PrimExpr b);
/*!
 * \brief Logical not.
 *
 * \param a The input operand.
 * \return The result expression.
 * \note This operator does eager constant folding.
 */
TVM_DLL PrimExpr operator!(PrimExpr a);
/*!
 * \brief Compute division in C semantics.
 *
 * a / b as in C/C++.
 *
 * When operands are integers, it directly corresponds to truncdiv.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr div(PrimExpr a, PrimExpr b);
/*!
 * \brief Compute trunc(a / b).
 *
 * This is the default integer division behavior in C.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr truncdiv(PrimExpr a, PrimExpr b);
/*!
 * \brief Compute the remainder of truncdiv.
 *
 * This is the default integer remainder (%) behavior in C.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr truncmod(PrimExpr a, PrimExpr b);
/*!
 * \brief Compute floor(a / b) where a and b are non-negative.
 *
 * Use this function for index split calculation.
 *
 * This function might take advantage of the fact
 * that a and b are non-negative.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr indexdiv(PrimExpr a, PrimExpr b);
/*!
 * \brief Compute the remainder of floor(a / b) where a and b are non-negative.
 *
 * Use this function for index split calculation.
 * This function might take advantage of the fact
 * that a and b are non-negative.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr indexmod(PrimExpr a, PrimExpr b);
/*!
 * \brief Compute floor(a / b).
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr floordiv(PrimExpr a, PrimExpr b);
/*!
 * \brief Compute the remainder of floordiv.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr floormod(PrimExpr a, PrimExpr b);
/*!
 * \brief Take the maximum of two values.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr max(PrimExpr a, PrimExpr b);
/*!
 * \brief Take the minimum of two values.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr min(PrimExpr a, PrimExpr b);
/*!
 * \brief Take the bitwise and of two values.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator&(PrimExpr a, PrimExpr b);
/*!
 * \brief Take the bitwise or of two values.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator|(PrimExpr a, PrimExpr b);
/*!
 * \brief Take the bitwise xor of two values.
 *
 * \param a left operand
 * \param b right operand
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator^(PrimExpr a, PrimExpr b);
/*!
 * \brief Take the bitwise negation (complement) of a single value.
 *
 * \param a The input expression.
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr operator~(PrimExpr a);
/*!
 * \brief Conditional expression.
 *
 * \param cond The condition.
 * \param true_value The value when the condition is true.
 * \param false_value The value when the condition is false.
 * \return The result expression.
 * \note This function does eager constant folding for
 *       index types (int32, int64) when possible.
 */
TVM_DLL PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value);
/*!
 * \brief Mark condition as likely.
 * \param cond The condition.
 * \return The marked expression.
 */
TVM_DLL PrimExpr likely(PrimExpr cond);
/*!
 * \brief Calculate power(x, y).
 * \param x The left operand (base).
 * \param y The right operand (exponent).
 * \return The result expression.
 */
TVM_DLL PrimExpr pow(PrimExpr x, PrimExpr y);
/*!
 * \brief Calculate the absolute value of x.
 * \param x The input data.
 *
 * \return The absolute value of input data x.
 */
TVM_DLL PrimExpr abs(PrimExpr x);
/*!
 * \brief Check if x is NaN.
 * \param x The input data.
 * \return The result expression.
 */
TVM_DLL PrimExpr isnan(PrimExpr x);

/*!
 * \brief Check if x is finite.
 * \param x The input data.
 * \return The result expression.
 */
TVM_DLL PrimExpr isfinite(PrimExpr x);

/*!
 * \brief Check if x is infinite.
 * \param x The input data.
 * \return The result expression.
 */
TVM_DLL PrimExpr isinf(PrimExpr x);
462
/*!
 * \brief Sum of the source expression over axis.
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 * \param init The value with which to initialize the output.
 * \return The result.
 */
TVM_DLL PrimExpr sum(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});

/*!
 * \brief Logical and of the source expression over axis.
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 * \param init The value with which to initialize the output.
 * \return The result.
 */
TVM_DLL PrimExpr all(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});

/*!
 * \brief Logical or of the source expression over axis.
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 * \param init The value with which to initialize the output.
 * \return The result.
 */
TVM_DLL PrimExpr any(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});

/*!
 * \brief Max of the source expression over axis.
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 * \param init The value with which to initialize the output.
 * \return The result.
 */
TVM_DLL PrimExpr max(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});

/*!
 * \brief Min of the source expression over axis.
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 * \param init The value with which to initialize the output.
 * \return The result.
 */
TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});

/*!
 * \brief Product of the source expression over axis.
 * \param source The source expression.
 * \param axis List of iteration variables that will be used for reduction.
 * \param init The value with which to initialize the output.
 * \return The result.
 */
TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {});
515
/*!
 * \brief Calculate floor(x).
 * \param x The input expression.
 * \return The result expression.
 */
TVM_DLL PrimExpr floor(PrimExpr x);

/*!
 * \brief Calculate ceil(x).
 * \param x The input expression.
 * \return The result expression.
 */
TVM_DLL PrimExpr ceil(PrimExpr x);

/*!
 * \brief Calculate round(x).
 * \param x The input expression.
 * \return The result expression.
 */
TVM_DLL PrimExpr round(PrimExpr x);

/*!
 * \brief Calculate std::nearbyint(x).
 * \param x The input expression.
 * \return The result expression.
 * This is a faster alternative to round.
 */
TVM_DLL PrimExpr nearbyint(PrimExpr x);

/*!
 * \brief Calculate trunc(x).
 * \param x The input expression.
 * \return The result expression.
 */
TVM_DLL PrimExpr trunc(PrimExpr x);
551
/*!
 * \brief Construct a large uint constant from its low 32 bits and high 32 bits.
 * \param dtype The final data type.
 * \param low The lower 32 bits.
 * \param high The higher 32 bits.
 * \return The constructed expression.
 */
TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high);

/*!
 * \brief Execute a multiplication between two Q-numbers x and y
 * followed by a right shift s. The mathematical expression is:
 *
 *    out = round(x*y*2^-s)
 *
 * Please note that the two Q-numbers x and y are supposed to have
 * the same number of fractional bits q.
 *
 * More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format)
 *
 * The rounding rule is to the nearest value, rounding half up
 * (i.e., round(x.1) = x and round(x.5) = x+1).
 * \param x first Q-number
 * \param y second Q-number
 * \param q number of fractional bits in x and y. Needs to be > 0
 * \param s integer right shift
 * \return The constructed expression.
 */
TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s);
581
// Intrinsic operators
//
// TVM_DECLARE_INTRIN_UNARY(OpName) defines `inline PrimExpr OpName(PrimExpr x)`.
// The generated function returns a tir::Call to the op registered as "tir.OpName",
// with x as the only argument and x.dtype() as the call's dtype. The Op handle is
// looked up with Op::Get once and kept in a function-local static.
#define TVM_DECLARE_INTRIN_UNARY(OpName)           \
  inline PrimExpr OpName(PrimExpr x) {             \
    static const Op& op = Op::Get("tir." #OpName); \
    return tir::Call(x.dtype(), op, {x});          \
  }

TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(exp2);
TVM_DECLARE_INTRIN_UNARY(exp10);
TVM_DECLARE_INTRIN_UNARY(erf);
TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid);
TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(rsqrt);
TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(log2);
TVM_DECLARE_INTRIN_UNARY(log10);
TVM_DECLARE_INTRIN_UNARY(popcount);
TVM_DECLARE_INTRIN_UNARY(tan);
TVM_DECLARE_INTRIN_UNARY(cos);
TVM_DECLARE_INTRIN_UNARY(cosh);
TVM_DECLARE_INTRIN_UNARY(sin);
TVM_DECLARE_INTRIN_UNARY(sinh);
TVM_DECLARE_INTRIN_UNARY(asin);
TVM_DECLARE_INTRIN_UNARY(acos);
TVM_DECLARE_INTRIN_UNARY(atan);
TVM_DECLARE_INTRIN_UNARY(acosh);
TVM_DECLARE_INTRIN_UNARY(asinh);
TVM_DECLARE_INTRIN_UNARY(atanh);
612
// TVM_DECLARE_INTRIN_BINARY(OpName) defines `inline PrimExpr OpName(PrimExpr x, PrimExpr y)`.
// The generated function returns a tir::Call to the op registered as "tir.OpName",
// with arguments {x, y}. The call's dtype is taken from the first operand (x.dtype()),
// and the Op handle is looked up once and kept in a function-local static.
#define TVM_DECLARE_INTRIN_BINARY(OpName)          \
  inline PrimExpr OpName(PrimExpr x, PrimExpr y) { \
    static const Op& op = Op::Get("tir." #OpName); \
    return tir::Call(x.dtype(), op, {x, y});       \
  }

TVM_DECLARE_INTRIN_BINARY(atan2);
TVM_DECLARE_INTRIN_BINARY(nextafter);
TVM_DECLARE_INTRIN_BINARY(copysign);
TVM_DECLARE_INTRIN_BINARY(hypot);
TVM_DECLARE_INTRIN_BINARY(ldexp);
624
625 namespace tir {
626
627 /*!
628 * \brief Check if type is a pointer to a runtime element type.
629 * \param type The type to be checked.
630 * \param element_type The corresponding element type.
631 * \return The check results
632 */
IsPointerType(const Type & type,const DataType & element_type)633 inline bool IsPointerType(const Type& type, const DataType& element_type) {
634 if (!type.defined()) return false;
635 if (const auto* ptr_type = type.as<PointerTypeNode>()) {
636 if (const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>()) {
637 return prim_type->dtype == element_type;
638 }
639 }
640 return false;
641 }
642
/*!
 * \brief Make a const value with certain data type.
 * \param t The target type; vector types produce a Broadcast of the scalar constant.
 * \param value The input value.
 * \return The result expression.
 * \tparam ValueType The constant value type (restricted to POD types via enable_if).
 */
template <typename ValueType,
          typename = typename std::enable_if<std::is_pod<ValueType>::value>::type>
inline PrimExpr make_const(DataType t, ValueType value);
/*!
 * \brief Make a const zero expression.
 * \param t The target type.
 * \return The result expression.
 */
inline PrimExpr make_zero(DataType t);
659 /*!
660 * \brief Make a constant true expression.
661 * \param lanes The number of lanes in the bool
662 * \return The result expression.
663 */
664 inline PrimExpr const_true(int lanes = 1) { return make_const(DataType::UInt(1, lanes), 1); }
665 /*!
666 * \brief Make a constant false expression.
667 * \param lanes The number of lanes in the bool
668 * \return The result expression.
669 */
670 inline PrimExpr const_false(int lanes = 1) { return make_const(DataType::UInt(1, lanes), 0); }
671 /*!
672 * \brief Get x as constant int expression.
673 * \param x The expression
674 * \return the address to the int expression,
675 * return nullptr, if x is not IntImm.
676 */
as_const_int(const PrimExpr & x)677 inline const int64_t* as_const_int(const PrimExpr& x) {
678 if (!x.defined()) return nullptr;
679 if (const tir::IntImmNode* op = x.as<tir::IntImmNode>()) {
680 return &(op->value);
681 } else {
682 return nullptr;
683 }
684 }
685
/*!
 * \brief Check whether x is a constant integer expression equal to value.
 * \param x The input argument.
 * \param value The value to be compared against.
 * \return true if x is an IntImm, or a Broadcast of an IntImm, whose value equals value.
 */
inline bool is_const_int(const PrimExpr& x, int64_t value);

/*!
 * \brief Check whether stmt is a no-op.
 * \param stmt The input statement.
 * \return true for an undefined stmt, an Evaluate of an integer constant,
 *         or an empty SeqStmt.
 */
inline bool is_no_op(const tir::Stmt& stmt);
700
701 /*!
702 * \brief Check whether x is a constant integer 1
703 * \param x The input argument.
704 * \note This only return true for integer types.
705 * \return whether x is constant 1
706 */
is_one(const PrimExpr & x)707 inline bool is_one(const PrimExpr& x) { return is_const_int(x, 1); }
708
709 /*!
710 * \brief Check whether x is a constant integer 0
711 * \param x The input argument
712 * \return whether x is constant 0
713 * \note This only return true for integer types.
714 */
is_zero(const PrimExpr & x)715 inline bool is_zero(const PrimExpr& x) { return is_const_int(x, 0); }
716
/*!
 * \brief Check whether x is an integer constant.
 * \param x The input argument.
 * \note This only returns true for integer types.
 * \return true if x is an IntImm or a Broadcast of an IntImm.
 */
inline bool is_const_int(const PrimExpr& x);

/*!
 * \brief Check whether x is an integer or float constant.
 * \param x The input argument.
 * \note Unlike is_const_int, this also returns true for float constants.
 * \return true if x is an IntImm or FloatImm, or a Broadcast of either.
 */
inline bool is_const_number(const PrimExpr& x);
730
/*!
 * \brief Left fold.
 *
 * Computes freduce(...freduce(freduce(init_value, values[0]), values[1])..., values[n-1]),
 * returning init_value unchanged when values is empty.
 *
 * \param freduce The reduction function.
 * \param init_value The initial value.
 * \param values The values to be folded.
 * \return The result.
 * \tparam FReduce The type of the reduction.
 */
template <typename FReduce>
inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr>& values);
741
/*!
 * \brief Check whether x is a constant power of two.
 * If x is a power of two, write the exponent to shift.
 *
 * \param x The input expression.
 * \param shift Output: the exponent (shift amount) if x is a power of two.
 * \return whether x is a constant power of two.
 */
TVM_DLL bool is_const_power_of_two_integer(const PrimExpr& x, int* shift);
751
752 // Implementation details after this
is_const_int(const PrimExpr & x)753 inline bool is_const_int(const PrimExpr& x) {
754 if (x.as<tir::IntImmNode>()) {
755 return true;
756 } else if (const auto* op = x.as<tir::BroadcastNode>()) {
757 const PrimExpr& val = op->value;
758 if (val.as<tir::IntImmNode>()) {
759 return true;
760 }
761 }
762 return false;
763 }
764
is_const_number(const PrimExpr & x)765 inline bool is_const_number(const PrimExpr& x) {
766 if (x.as<tir::IntImmNode>()) {
767 return true;
768 } else if (x.as<tir::FloatImmNode>()) {
769 return true;
770 } else if (const auto* op = x.as<tir::BroadcastNode>()) {
771 return (op->value->IsInstance<tir::IntImmNode>() || op->value->IsInstance<tir::FloatImmNode>());
772 }
773 return false;
774 }
775
is_positive_const(const PrimExpr & a)776 inline bool is_positive_const(const PrimExpr& a) {
777 if (const tir::IntImmNode* op = a.as<tir::IntImmNode>()) {
778 return op->value > 0;
779 } else {
780 return false;
781 }
782 }
783
is_negative_const(const PrimExpr & a)784 inline bool is_negative_const(const PrimExpr& a) {
785 if (const tir::IntImmNode* op = a.as<tir::IntImmNode>()) {
786 return op->value < 0;
787 } else {
788 return false;
789 }
790 }
791
is_const_int(const PrimExpr & x,int64_t value)792 inline bool is_const_int(const PrimExpr& x, int64_t value) {
793 if (const auto* op = x.as<tir::IntImmNode>()) {
794 return op->value == value;
795 } else if (const auto* op = x.as<tir::BroadcastNode>()) {
796 const PrimExpr& val = op->value;
797 if (const auto* opv = val.as<tir::IntImmNode>()) {
798 return opv->value == value;
799 }
800 }
801 return false;
802 }
803
is_no_op(const tir::Stmt & stmt)804 inline bool is_no_op(const tir::Stmt& stmt) {
805 if (!stmt.defined()) return true;
806 if (const auto* op = stmt.as<tir::EvaluateNode>()) {
807 return is_const_int(op->value);
808 }
809 if (const auto* op = stmt.as<tir::SeqStmtNode>()) {
810 return op->seq.size() == 0;
811 }
812 return false;
813 }
814
815 template <typename ValueType>
MakeConstScalar(DataType t,ValueType value)816 inline PrimExpr MakeConstScalar(DataType t, ValueType value) {
817 if (t.is_int()) return IntImm(t, static_cast<int64_t>(value));
818 if (t.is_uint()) {
819 // Use IntImm if it is a small integer
820 uint64_t uval = static_cast<uint64_t>(value);
821 if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
822 return IntImm(t, static_cast<int64_t>(value));
823 } else {
824 uint64_t mask = (static_cast<uint64_t>(1) << 32U) - 1U;
825 uint64_t low = uval & mask;
826 uint64_t high = uval >> 32U;
827 return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high));
828 }
829 }
830 if (t.is_float() || t.is_bfloat16()) return FloatImm(t, static_cast<double>(value));
831 // For now, we store const scalar values of custom datatypes within doubles; later, during the
832 // datatypes lowering pass, we will lower the value to its true representation in the format
833 // specified by the datatype.
834 // TODO(gus) when do we need to start worrying about doubles not being precise enough?
835 if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(DataType::kCustomBegin)) {
836 return FloatImm(t, static_cast<double>(value));
837 }
838 LOG(FATAL) << "cannot make const for type " << t;
839 return PrimExpr();
840 }
841
842 template <typename ValueType, typename>
make_const(DataType t,ValueType value)843 inline PrimExpr make_const(DataType t, ValueType value) {
844 if (t.lanes() == 1) {
845 return MakeConstScalar(t, value);
846 } else {
847 return tir::Broadcast(MakeConstScalar(t.element_of(), value), t.lanes());
848 }
849 }
850
make_zero(DataType t)851 inline PrimExpr make_zero(DataType t) {
852 if (t.is_handle()) {
853 return reinterpret(t, make_const(DataType::UInt(64), 0));
854 }
855 return make_const(t, 0);
856 }
857
858 template <typename FReduce>
foldl(FReduce freduce,PrimExpr init_value,const Array<PrimExpr> & values)859 inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr>& values) {
860 for (PrimExpr val : values) {
861 init_value = freduce(init_value, val);
862 }
863 return init_value;
864 }
865
866 } // namespace tir
867
// Additional overloads that accept plain C++ constants alongside PrimExpr.

// TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc): compound assignment a = OpFunc(a, b).
// Note that the generated operator returns the updated value by value (a PrimExpr copy),
// not a reference to a.
#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \
  inline PrimExpr Name(PrimExpr& a, PrimExpr b) {   \
    a = OpFunc(a, b);                               \
    return a;                                       \
  }

// TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name): mixed PrimExpr/constant overloads.
//  - float (either side): wrapped with the PrimExpr(float) constructor.
//  - int (either side): converted with tir::make_const to the other operand's dtype.
//  - double (right-hand side only): converted to a Float(64) constant.
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name)                                   \
  inline PrimExpr Name(const PrimExpr& a, float b) { return Name(a, PrimExpr(b)); } \
  inline PrimExpr Name(float a, const PrimExpr& b) { return Name(PrimExpr(a), b); } \
  inline PrimExpr Name(int a, const PrimExpr& b) {                                  \
    return Name(tir::make_const(b.dtype(), a), b);                                  \
  }                                                                                 \
  inline PrimExpr Name(const PrimExpr& a, int b) {                                  \
    return Name(a, tir::make_const(a.dtype(), b));                                  \
  }                                                                                 \
  inline PrimExpr Name(const PrimExpr& a, double b) {                               \
    return Name(a, tir::make_const(DataType::Float(64), b));                        \
  }

// TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name): bool operands (either side) are
// wrapped with the PrimExpr(bool) constructor.
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name)                             \
  inline PrimExpr Name(const PrimExpr& a, bool b) { return Name(a, PrimExpr(b)); } \
  inline PrimExpr Name(bool a, const PrimExpr& b) { return Name(PrimExpr(a), b); }

// TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name): int operands (either side) are converted
// with tir::make_const to the other operand's dtype. Only int is accepted, for
// integer-only operations.
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name)   \
  inline PrimExpr Name(const PrimExpr& a, int b) {  \
    return Name(a, tir::make_const(a.dtype(), b));  \
  }                                                 \
  inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tir::make_const(b.dtype(), a), b); }

TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator-=, operator-);
TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator*=, operator*);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator+);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator-);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator*);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(max);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(min);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(div);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>);  // NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>=);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<);  // NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<=);
// integer related ops
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(indexdiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(indexmod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncdiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncmod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floordiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floormod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>);  // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<);  // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator|);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator^);
// logical ops
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator&&);
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator||);
926
/*!
 * \brief Helper function to raise a compiler error about division ambiguity.
 *
 * The static_assert condition is false for every class type, so instantiating this
 * template with a class such as PrimExpr always fails to compile with a message that
 * directs the developer to the explicit division functions.
 *
 * \note A call to this function with a class-type argument always results in a compiler error.
 * \tparam TA Any class type.
 * \param a The operand of the ambiguous division; only its type matters.
 */
template <typename TA>
inline void DivAmbiguityError(const TA& a) {
  (void)a;  // only the type TA is inspected
  constexpr bool div_ambiguity = !std::is_class<TA>::value;
  static_assert(div_ambiguity,
                "TVM supports multiple types of integer divisions, "
                "please call div, indexdiv/indexmod, "
                "floordiv/floormod or truncdiv/truncmod directly "
                "to avoid ambiguity in the code. "
                "Check out these functions in tvm/tir/op.h.");
}
942
943 // The following code are not intended to be used in the codebase.
944 // Instead, they generate clear compiler errors that ask developers
945 // to use the specific division function.
946 // The second template argument is necessary to make sure the
947 // code compiles lazily by the compiler during invocation.
948 template <typename TB>
949 inline PrimExpr operator/(const PrimExpr& a, const TB& b) {
950 DivAmbiguityError(a);
951 return a;
952 }
953
954 template <typename TB>
955 inline PrimExpr operator/=(const PrimExpr& a, const TB& b) {
956 DivAmbiguityError(a);
957 return a;
958 }
959
960 template <typename TB>
961 inline PrimExpr operator%(const PrimExpr& a, const TB& b) {
962 DivAmbiguityError(a);
963 return a;
964 }
965 } // namespace tvm
966 #endif // TVM_TIR_OP_H_
967