1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file expr_operator.cc
22 */
23 #include <tvm/base.h>
24 #include <tvm/ir.h>
25 #include <tvm/expr_operator.h>
26 #include <cmath>
27 // Centralized header for constant folders.
28 #include "../arithmetic/const_fold.h"
29
30 namespace tvm {
31
32 // simple cast that only checks if type matches and cast
SimpleCast(const Type & t,Expr value)33 inline Expr SimpleCast(const Type& t, Expr value) {
34 if (value.type() == t) return value;
35 return ir::Cast::make(t, value);
36 }
37
38 // The public function with a quick checking path.
BinaryOpMatchTypes(Expr & lhs,Expr & rhs)39 void BinaryOpMatchTypes(Expr& lhs, Expr& rhs) { // NOLINT(*)
40 if (lhs.type() == rhs.type()) return;
41 Type ltype = lhs.type();
42 Type rtype = rhs.type();
43 if (ltype.lanes() == 1 && rtype.lanes() != 1) {
44 lhs = ir::Broadcast::make(lhs, rtype.lanes());
45 } else if (rtype.lanes() == 1 && ltype.lanes() != 1) {
46 rhs = ir::Broadcast::make(rhs, ltype.lanes());
47 } else {
48 CHECK(ltype.lanes() == rtype.lanes())
49 << "Cannot match type " << ltype << " vs " << rtype;
50 }
51 if (lhs.type() == rhs.type()) return;
52 // Only do very simple type coversion
53 // int->float, int(32)->int(64)
54 // require the types to be relatively consistent
55 // This will the reduce amount code generated by operators
56 // and also help user to find potential type conversion problems.
57 if (!lhs.type().is_float() && rhs.type().is_float()) {
58 // int->float
59 lhs = cast(rhs.type(), lhs);
60 } else if (lhs.type().is_float() && !rhs.type().is_float()) {
61 // int->float
62 rhs = cast(lhs.type(), rhs);
63 } else if ((lhs.type().is_int() && rhs.type().is_int()) ||
64 (lhs.type().is_uint() && rhs.type().is_uint())) {
65 // promote int to higher bits
66 if (lhs.type().bits() < rhs.type().bits()) {
67 lhs = cast(rhs.type(), lhs);
68 } else {
69 rhs = cast(lhs.type(), rhs);
70 }
71 } else if ((lhs.type().is_int() && rhs.type().is_uint()) ||
72 (lhs.type().is_uint() && rhs.type().is_int())) {
73 int bits = std::max(lhs.type().bits(), rhs.type().bits());
74 lhs = SimpleCast(Int(bits, lhs.type().lanes()), lhs);
75 rhs = SimpleCast(Int(bits, rhs.type().lanes()), rhs);
76 } else {
77 LOG(FATAL) << "Cannot match type " << ltype << " vs " << rtype;
78 }
79 }
80
81
// Return true iff `val` is a positive power of two. For any positive `val`,
// shift[0] receives the number of trailing zero bits (the exponent when the
// result is true). For val <= 0 the function returns false and leaves
// shift[0] untouched.
template<typename ValueType>
inline bool ConstPowerHelper(ValueType val, int *shift) {
  if (!(val > 0)) return false;
  int trailing_zeros = 0;
  // A positive value always has a set bit, so this loop terminates.
  while ((val & 1) == 0) {
    val = val >> 1;
    ++trailing_zeros;
  }
  shift[0] = trailing_zeros;
  // After stripping trailing zeros, only a power of two leaves exactly 1.
  return val == 1;
}
95
is_const_power_of_two_integer(const Expr & x,int * shift)96 bool is_const_power_of_two_integer(const Expr& x, int* shift) {
97 if (const auto* op = x.as<ir::IntImm>()) {
98 return ConstPowerHelper(op->value, shift);
99 } else if (const auto* op = x.as<ir::UIntImm>()) {
100 return ConstPowerHelper(op->value, shift);
101 } else {
102 return false;
103 }
104 }
105
cast(const Type & t,Expr value)106 Expr cast(const Type& t, Expr value) {
107 using ir::IntImm;
108 using ir::UIntImm;
109 using ir::FloatImm;
110 if (value.type() == t) return value;
111 // const fold IntImm as they are used in index computations
112 if (t.lanes() == 1) {
113 if (const IntImm* op = value.as<IntImm>()) {
114 return make_const(t, op->value);
115 } else if (const UIntImm* op = value.as<UIntImm>()) {
116 return make_const(t, op->value);
117 } else if (const FloatImm* op = value.as<FloatImm>()) {
118 return make_const(t, op->value);
119 }
120 return ir::Cast::make(t, value);
121 } else {
122 if (value.type().lanes() == 1) {
123 // manually unroll cast
124 Type vtype = t.element_of();
125 if (value.type() != vtype) {
126 if (const IntImm* op = value.as<IntImm>()) {
127 value = make_const(vtype, op->value);
128 } else if (const UIntImm* op = value.as<UIntImm>()) {
129 return make_const(t, op->value);
130 } else if (const FloatImm* op = value.as<FloatImm>()) {
131 value = make_const(vtype, op->value);
132 } else {
133 value = ir::Cast::make(vtype, value);
134 }
135 }
136 return ir::Broadcast::make(value, t.lanes());
137 } else {
138 CHECK(value.type().lanes() == t.lanes());
139 return ir::Cast::make(t, value);
140 }
141 }
142 }
143
reinterpret(const Type & t,Expr value)144 Expr reinterpret(const Type& t, Expr value) {
145 if (value.type() == t) return value;
146 return ir::Call::make(t, ir::Call::reinterpret, { value }, ir::Call::PureIntrinsic);
147 }
148
operator +(Expr a,Expr b)149 Expr operator+(Expr a, Expr b) {
150 BinaryOpMatchTypes(a, b);
151 Expr ret = arith::TryConstFold<ir::Add>(a, b);
152 if (ret.defined()) return ret;
153 return ir::Add::make(a, b);
154 }
155
156 // negation
operator -(Expr a)157 Expr operator-(Expr a) {
158 using ir::IntImm;
159 using ir::FloatImm;
160 const IntImm* pa = a.as<IntImm>();
161 const FloatImm* fa = a.as<FloatImm>();
162 if (pa) return ir::IntImm::make(a.type(), -pa->value);
163 if (fa) return ir::FloatImm::make(a.type(), -fa->value);
164 return make_zero(a.type()) - a;
165 }
166
operator -(Expr a,Expr b)167 Expr operator-(Expr a, Expr b) {
168 BinaryOpMatchTypes(a, b);
169 Expr ret = arith::TryConstFold<ir::Sub>(a, b);
170 if (ret.defined()) return ret;
171 return ir::Sub::make(a, b);
172 }
173
operator *(Expr a,Expr b)174 Expr operator*(Expr a, Expr b) {
175 BinaryOpMatchTypes(a, b);
176 Expr ret = arith::TryConstFold<ir::Mul>(a, b);
177 if (ret.defined()) return ret;
178 return ir::Mul::make(a, b);
179 }
180
div(Expr a,Expr b)181 Expr div(Expr a, Expr b) {
182 BinaryOpMatchTypes(a, b);
183 Expr ret = arith::TryConstFold<ir::Div>(a, b);
184 if (ret.defined()) return ret;
185 return ir::Div::make(a, b);
186 }
187
truncdiv(Expr a,Expr b)188 Expr truncdiv(Expr a, Expr b) {
189 CHECK(a.type().is_int() || a.type().is_uint());
190 CHECK(b.type().is_int() || b.type().is_uint());
191 return div(a, b);
192 }
193
truncmod(Expr a,Expr b)194 Expr truncmod(Expr a, Expr b) {
195 BinaryOpMatchTypes(a, b);
196 Expr ret = arith::TryConstFold<ir::Mod>(a, b);
197 if (ret.defined()) return ret;
198 return ir::Mod::make(a, b);
199 }
200
operator /(Expr a,Expr b)201 Expr operator/(Expr a, Expr b) {
202 return div(a, b);
203 }
204
operator %(Expr a,Expr b)205 Expr operator%(Expr a, Expr b) {
206 return truncmod(a, b);
207 }
208
209 // TODO(tqchen): switch to floordiv
indexdiv(Expr a,Expr b)210 Expr indexdiv(Expr a, Expr b) {
211 return floordiv(a, b);
212 }
213
indexmod(Expr a,Expr b)214 Expr indexmod(Expr a, Expr b) {
215 return floormod(a, b);
216 }
217
floordiv(Expr a,Expr b)218 Expr floordiv(Expr a, Expr b) {
219 CHECK(a.type().is_int() || a.type().is_uint());
220 CHECK(b.type().is_int() || b.type().is_uint());
221 BinaryOpMatchTypes(a, b);
222 Expr ret = arith::TryConstFold<ir::FloorDiv>(a, b);
223 if (ret.defined()) return ret;
224 return ir::FloorDiv::make(a, b);
225 }
226
floormod(Expr a,Expr b)227 Expr floormod(Expr a, Expr b) {
228 CHECK(a.type().is_int() || a.type().is_uint());
229 CHECK(b.type().is_int() || b.type().is_uint());
230 BinaryOpMatchTypes(a, b);
231 Expr ret = arith::TryConstFold<ir::FloorMod>(a, b);
232 if (ret.defined()) return ret;
233 return ir::FloorMod::make(a, b);
234 }
235
min(Expr a,Expr b)236 Expr min(Expr a, Expr b) {
237 // inf-aware simplificaiton
238 using arith::is_pos_inf;
239 using arith::is_neg_inf;
240 if (is_pos_inf(a)) return b;
241 if (is_neg_inf(a)) return a;
242 if (is_pos_inf(b)) return a;
243 if (is_neg_inf(b)) return b;
244 BinaryOpMatchTypes(a, b);
245 Expr ret = arith::TryConstFold<ir::Min>(a, b);
246 if (ret.defined()) return ret;
247 return ir::Min::make(a, b);
248 }
249
max(Expr a,Expr b)250 Expr max(Expr a, Expr b) {
251 // inf-aware simplificaiton
252 using arith::is_pos_inf;
253 using arith::is_neg_inf;
254 if (is_pos_inf(a)) return a;
255 if (is_neg_inf(a)) return b;
256 if (is_pos_inf(b)) return b;
257 if (is_neg_inf(b)) return a;
258 BinaryOpMatchTypes(a, b);
259 Expr ret = arith::TryConstFold<ir::Max>(a, b);
260 if (ret.defined()) return ret;
261 return ir::Max::make(a, b);
262 }
263
if_then_else(Expr cond,Expr true_value,Expr false_value)264 Expr if_then_else(Expr cond, Expr true_value, Expr false_value) {
265 using ir::IntImm;
266 using ir::UIntImm;
267 CHECK(cond.type() == Bool(1))
268 << "if_then_else only accept the condition to be boolean type.";
269 BinaryOpMatchTypes(true_value, false_value);
270 if (const UIntImm* op = cond.as<UIntImm>()) {
271 if (op->value != 0) {
272 return true_value;
273 } else {
274 return false_value;
275 }
276 } else if (const IntImm* op = cond.as<IntImm>()) {
277 if (op->value != 0) {
278 return true_value;
279 } else {
280 return false_value;
281 }
282 }
283 return ir::Call::make(
284 true_value.type(),
285 ir::intrinsic::tvm_if_then_else,
286 {cond, true_value, false_value},
287 ir::Call::PureIntrinsic);
288 }
289
likely(Expr cond)290 Expr likely(Expr cond) {
291 if (is_const(cond)) return cond;
292 return ir::Call::make(cond.type(), ir::Call::likely, { cond }, ir::Call::PureIntrinsic);
293 }
294
operator >(Expr a,Expr b)295 Expr operator>(Expr a, Expr b) {
296 BinaryOpMatchTypes(a, b);
297 Expr ret = arith::TryConstFold<ir::GT>(a, b);
298 if (ret.defined()) return ret;
299 return ir::GT::make(a, b);
300 }
301
operator >=(Expr a,Expr b)302 Expr operator>=(Expr a, Expr b) {
303 BinaryOpMatchTypes(a, b);
304 Expr ret = arith::TryConstFold<ir::GE>(a, b);
305 if (ret.defined()) return ret;
306 return ir::GE::make(a, b);
307 }
308
operator <(Expr a,Expr b)309 Expr operator<(Expr a, Expr b) {
310 BinaryOpMatchTypes(a, b);
311 Expr ret = arith::TryConstFold<ir::LT>(a, b);
312 if (ret.defined()) return ret;
313 return ir::LT::make(a, b);
314 }
315
operator <=(Expr a,Expr b)316 Expr operator<=(Expr a, Expr b) {
317 BinaryOpMatchTypes(a, b);
318 Expr ret = arith::TryConstFold<ir::LE>(a, b);
319 if (ret.defined()) return ret;
320 return ir::LE::make(a, b);
321 }
322
operator ==(Expr a,Expr b)323 Expr operator==(Expr a, Expr b) {
324 BinaryOpMatchTypes(a, b);
325 Expr ret = arith::TryConstFold<ir::EQ>(a, b);
326 if (ret.defined()) return ret;
327 return ir::EQ::make(a, b);
328 }
329
operator !=(Expr a,Expr b)330 Expr operator!=(Expr a, Expr b) {
331 BinaryOpMatchTypes(a, b);
332 Expr ret = arith::TryConstFold<ir::NE>(a, b);
333 if (ret.defined()) return ret;
334 return ir::NE::make(a, b);
335 }
336
operator &&(Expr a,Expr b)337 Expr operator&&(Expr a, Expr b) {
338 CHECK(a.type().is_bool());
339 CHECK(b.type().is_bool());
340 Expr ret = arith::TryConstFold<ir::And>(a, b);
341 if (ret.defined()) return ret;
342 return ir::And::make(a, b);
343 }
344
operator ||(Expr a,Expr b)345 Expr operator||(Expr a, Expr b) {
346 CHECK(a.type().is_bool());
347 CHECK(b.type().is_bool());
348 Expr ret = arith::TryConstFold<ir::Or>(a, b);
349 if (ret.defined()) return ret;
350 return ir::Or::make(a, b);
351 }
352
operator !(Expr a)353 Expr operator!(Expr a) {
354 CHECK(a.type().is_bool());
355 Expr ret = arith::TryConstFold<ir::Not>(a);
356 if (ret.defined()) return ret;
357 return ir::Not::make(a);
358 }
359
operator >>(Expr a,Expr b)360 Expr operator>>(Expr a, Expr b) {
361 BinaryOpMatchTypes(a, b);
362 TVM_INDEX_CONST_PROPAGATION({
363 const Type& rtype = a.type();
364 if (pa && pb) return IntImm::make(rtype, (pa->value >> pb->value));
365 if (pb) {
366 if (pb->value == 0) return a;
367 }
368 });
369 return ir::Call::make(a.type(), ir::Call::shift_right, { a, b }, ir::Call::PureIntrinsic);
370 }
371
operator <<(Expr a,Expr b)372 Expr operator<<(Expr a, Expr b) {
373 BinaryOpMatchTypes(a, b);
374 TVM_INDEX_CONST_PROPAGATION({
375 const Type& rtype = a.type();
376 if (pa && pb) return IntImm::make(rtype, (pa->value << pb->value));
377 if (pb) {
378 if (pb->value == 0) return a;
379 }
380 });
381 return ir::Call::make(a.type(), ir::Call::shift_left, { a, b }, ir::Call::PureIntrinsic);
382 }
383
operator &(Expr a,Expr b)384 Expr operator&(Expr a, Expr b) {
385 BinaryOpMatchTypes(a, b);
386 TVM_INDEX_CONST_PROPAGATION({
387 const Type& rtype = a.type();
388 if (pa && pb) return IntImm::make(rtype, (pa->value & pb->value));
389 });
390 return ir::Call::make(a.type(), ir::Call::bitwise_and, { a, b }, ir::Call::PureIntrinsic);
391 }
392
operator |(Expr a,Expr b)393 Expr operator|(Expr a, Expr b) {
394 BinaryOpMatchTypes(a, b);
395 TVM_INDEX_CONST_PROPAGATION({
396 const Type& rtype = a.type();
397 if (pa && pb) return IntImm::make(rtype, (pa->value | pb->value));
398 });
399 return ir::Call::make(a.type(), ir::Call::bitwise_or, { a, b }, ir::Call::PureIntrinsic);
400 }
401
operator ^(Expr a,Expr b)402 Expr operator^(Expr a, Expr b) {
403 BinaryOpMatchTypes(a, b);
404 TVM_INDEX_CONST_PROPAGATION({
405 const Type& rtype = a.type();
406 if (pa && pb) return IntImm::make(rtype, (pa->value ^ pb->value));
407 });
408 return ir::Call::make(a.type(), ir::Call::bitwise_xor, { a, b }, ir::Call::PureIntrinsic);
409 }
410
operator ~(Expr a)411 Expr operator~(Expr a) {
412 CHECK(a.type().is_int() || a.type().is_uint());
413 return ir::Call::make(a.type(), ir::Call::bitwise_not, { a }, ir::Call::PureIntrinsic);
414 }
415
pow(Expr x,Expr y)416 Expr pow(Expr x, Expr y) {
417 BinaryOpMatchTypes(x, y);
418 CHECK(x.type().is_float()) << "power only applies to float";
419 return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic);
420 }
421
abs(Expr x)422 Expr abs(Expr x) {
423 if (x.type().is_int()) {
424 using ir::IntImm;
425 const IntImm* px = x.as<IntImm>();
426 if (px) {
427 return ir::IntImm::make(x.type(), std::abs(px->value));
428 }
429 return ir::Select::make(x >= make_zero(x.type()), x, -x);
430 } else if (x.type().is_float()) {
431 using ir::FloatImm;
432 const FloatImm* fx = x.as<FloatImm>();
433 if (fx) {
434 return ir::FloatImm::make(x.type(), std::fabs(fx->value));
435 }
436 return ir::Call::make(x.type(), "fabs", {x}, ir::Call::PureIntrinsic);
437 } else if (x.type().is_uint()) {
438 return x;
439 } else {
440 LOG(FATAL) << "Data type " << x.type()
441 <<" not supported for absolute op. Skipping absolute op...";
442 return x;
443 }
444 }
445
isnan(Expr x)446 Expr isnan(Expr x) {
447 Type t = Bool(x.type().lanes());
448 if (x.type().is_int() || x.type().is_uint()) {
449 return make_const(t, false);
450 } else if (x.type().is_float()) {
451 using ir::FloatImm;
452 const FloatImm* fx = x.as<FloatImm>();
453 if (fx) {
454 return make_const(t, std::isnan(fx->value));
455 }
456 if (x.type().bits() == 16) {
457 return ir::Call::make(t, ir::Call::isnan,
458 {cast(Float(32, t.lanes()), std::move(x))},
459 ir::Call::PureIntrinsic);
460 } else {
461 return ir::Call::make(t, ir::Call::isnan, {x}, ir::Call::PureIntrinsic);
462 }
463 } else {
464 LOG(FATAL) << "Data type " << x.type()
465 <<" not supported for isnan op. Skipping isnan op...";
466 return x;
467 }
468 }
469
sum(Expr source,Array<IterVar> rdom)470 Expr sum(Expr source, Array<IterVar> rdom) {
471 Var x("x", source.type()), y("y", source.type());
472 Expr result = ir::Add::make(x, y);
473 Expr identity_element = make_zero(source.type());
474 ir::CommReducer combiner =
475 ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
476 return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
477 }
478
all(Expr source,Array<IterVar> rdom)479 Expr all(Expr source, Array<IterVar> rdom) {
480 CHECK(source.type().is_bool());
481 Var x("x", source.type()), y("y", source.type());
482 Expr result = ir::And::make(x, y);
483 Expr identity_element = make_const(source.type(), true);
484 ir::CommReducer combiner =
485 ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
486 return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
487 }
488
any(Expr source,Array<IterVar> rdom)489 Expr any(Expr source, Array<IterVar> rdom) {
490 CHECK(source.type().is_bool());
491 Var x("x", source.type()), y("y", source.type());
492 Expr result = ir::Or::make(x, y);
493 Expr identity_element = make_const(source.type(), false);
494 ir::CommReducer combiner =
495 ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
496 return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
497 }
498
max(Expr source,Array<IterVar> rdom)499 Expr max(Expr source, Array<IterVar> rdom) {
500 Var x("x", source.type()), y("y", source.type());
501 Expr result = ir::Max::make(x, y);
502 Expr identity_element = source.type().min();
503 ir::CommReducer combiner =
504 ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
505 return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
506 }
507
min(Expr source,Array<IterVar> rdom)508 Expr min(Expr source, Array<IterVar> rdom) {
509 Var x("x", source.type()), y("y", source.type());
510 Expr result = ir::Min::make(x, y);
511 Expr identity_element = source.type().max();
512 ir::CommReducer combiner =
513 ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
514 return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
515 }
516
prod(Expr source,Array<IterVar> rdom)517 Expr prod(Expr source, Array<IterVar> rdom) {
518 Var x("x", source.type()), y("y", source.type());
519 Expr result = ir::Mul::make(x, y);
520 Expr identity_element = make_const(source.type(), 1);
521 ir::CommReducer combiner =
522 ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
523 return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
524 }
525
fmod(Expr x,Expr y)526 Expr fmod(Expr x, Expr y) {
527 BinaryOpMatchTypes(x, y);
528 CHECK(x.type().is_float()) << "fmod only applies to float";
529 return ir::Call::make(x.type(), "fmod", { x, y }, ir::Call::PureIntrinsic);
530 }
531
floor(Expr x)532 Expr floor(Expr x) {
533 using ir::FloatImm;
534 const FloatImm* fx = x.as<FloatImm>();
535 if (fx) return FloatImm::make(x.type(), std::floor(fx->value));
536 return ir::Call::make(x.type(), "floor", {x}, ir::Call::PureIntrinsic);
537 }
538
ceil(Expr x)539 Expr ceil(Expr x) {
540 using ir::FloatImm;
541 const FloatImm* fx = x.as<FloatImm>();
542 if (fx) return FloatImm::make(x.type(), std::ceil(fx->value));
543 return ir::Call::make(x.type(), "ceil", {x}, ir::Call::PureIntrinsic);
544 }
545
round(Expr x)546 Expr round(Expr x) {
547 using ir::FloatImm;
548 const FloatImm* fx = x.as<FloatImm>();
549 if (fx) return FloatImm::make(x.type(), std::nearbyint(fx->value));
550 return ir::Call::make(x.type(), "round", {x}, ir::Call::PureIntrinsic);
551 }
552
nearbyint(Expr x)553 Expr nearbyint(Expr x) {
554 using ir::FloatImm;
555 const FloatImm* fx = x.as<FloatImm>();
556 if (fx) return FloatImm::make(x.type(), std::nearbyint(fx->value));
557 return ir::Call::make(x.type(), "nearbyint", {x}, ir::Call::PureIntrinsic);
558 }
559
trunc(Expr x)560 Expr trunc(Expr x) {
561 using ir::FloatImm;
562 const FloatImm* fx = x.as<FloatImm>();
563 if (fx) {
564 return FloatImm::make(x.type(), (fx->value < 0 ? std::ceil(fx->value) :
565 std::floor(fx->value)));
566 }
567 return ir::Call::make(x.type(), "trunc", {x}, ir::Call::PureIntrinsic);
568 }
569
570 } // namespace tvm
571