1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file const_fold.h
22 * \brief Centralized location for constant folding.
23 */
24 #ifndef TVM_ARITHMETIC_CONST_FOLD_H_
25 #define TVM_ARITHMETIC_CONST_FOLD_H_
26
27 #include <tvm/ir.h>
28 #include <tvm/ir_mutator.h>
29 #include <tvm/expr_operator.h>
30 #include <algorithm>
31 #include <cmath>
32 #include "int_operator.h"
33
34 namespace tvm {
35 namespace arith {
36
37 /*!
38 * \brief Try to run binary compute with constant folding.
39 *
40 * \param a The left operand.
41 * \param b The right operand.
42 * \tparam Op The operator type.
43 *
44 * \note a and b Must already matched data types with each other.
45 * \return nullptr if constant fold fails, otherwise return folded result.
46 */
47 template<typename Op>
TryConstFold(Expr a,Expr b)48 inline Expr TryConstFold(Expr a, Expr b) {
49 return Expr();
50 }
51
/*!
 * \brief Try to run unary compute with constant folding.
 *
 * Only the declaration lives here; the folding logic is supplied by
 * per-operator specializations.
 *
 * \param a The (single) operand.
 * \tparam Op The operator type.
 *
 * \return nullptr if constant fold fails, otherwise return folded result.
 */
template<typename Op>
inline Expr TryConstFold(Expr a);
63
64 /*!
65 * \brief Check whether type is used to represent index.
66 *
67 * Index types are frequently used in shape computation
68 * and need to be aggressively constant-folded.
69 *
70 * \param type The type to represent index.
71 * \return the checked result.
72 */
IsIndexType(const Type & type)73 inline bool IsIndexType(const Type& type) {
74 return type.is_int() && type.lanes() == 1 &&
75 (type.bits() == 32 || type.bits() == 64);
76 }
77
78
/*
 * Prologue shared by the arithmetic/comparison folders.
 *
 * Must be expanded inside a function that has Expr operands named `a`
 * and `b`. It brings the immediate node types into scope, views each
 * operand both as an IntImm (pa / pb) and as a FloatImm (fa / fb), and
 * then runs BODY. BODY is expected to test these pointers before
 * dereferencing them.
 */
#define TVM_ARITH_CONST_PROPAGATION(BODY)         \
  using ir::FloatImm;                             \
  using ir::UIntImm;                              \
  using ir::IntImm;                               \
  const IntImm* pa = a.as<IntImm>();              \
  const FloatImm* fa = a.as<FloatImm>();          \
  const IntImm* pb = b.as<IntImm>();              \
  const FloatImm* fb = b.as<FloatImm>();          \
  BODY;
88
89
/*
 * Prologue shared by the folders that only apply to index types.
 *
 * Must be expanded inside a function that has Expr operands named `a`
 * and `b`. It views each operand as an IntImm (pa / pb), records both
 * operand types (ta / tb) and runs BODY only when both types satisfy
 * arith::IsIndexType. BODY is expected to test pa / pb before
 * dereferencing them.
 */
#define TVM_INDEX_CONST_PROPAGATION(BODY)                       \
  using ir::UIntImm;                                            \
  using ir::IntImm;                                             \
  const Type& ta = a.type();                                    \
  const Type& tb = b.type();                                    \
  const IntImm* pa = a.as<IntImm>();                            \
  const IntImm* pb = b.as<IntImm>();                            \
  if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) {       \
    BODY;                                                       \
  }

101
102 // specialization of constant folders.
103 template<>
104 inline Expr TryConstFold<ir::Add>(Expr a, Expr b) {
105 TVM_ARITH_CONST_PROPAGATION({
106 const Type& rtype = a.type();
107 if (pa && pb) return IntImm::make(rtype, pa->value + pb->value);
108 if (pa && pa->value == 0) return b;
109 if (pb && pb->value == 0) return a;
110 if (fa && fb) return FloatImm::make(rtype, fa->value + fb->value);
111 if (fa && fa->value == 0) return b;
112 if (fb && fb->value == 0) return a;
113 });
114 return Expr();
115 }
116
117 template<>
118 inline Expr TryConstFold<ir::Sub>(Expr a, Expr b) {
119 TVM_ARITH_CONST_PROPAGATION({
120 const Type& rtype = a.type();
121 if (pa && pb) return IntImm::make(rtype, pa->value - pb->value);
122 if (pb && pb->value == 0) return a;
123 if (fa && fb) return FloatImm::make(rtype, fa->value - fb->value);
124 if (fb && fb->value == 0) return a;
125 });
126 return Expr();
127 }
128
129 template<>
130 inline Expr TryConstFold<ir::Mul>(Expr a, Expr b) {
131 TVM_ARITH_CONST_PROPAGATION({
132 const Type& rtype = a.type();
133 if (pa && pb) return IntImm::make(rtype, pa->value * pb->value);
134 if (pa) {
135 if (pa->value == 1) return b;
136 if (pa->value == 0) return a;
137 }
138 if (pb) {
139 if (pb->value == 1) return a;
140 if (pb->value == 0) return b;
141 }
142 if (fa && fb) return FloatImm::make(rtype, fa->value * fb->value);
143 if (fa) {
144 if (fa->value == 1) return b;
145 if (fa->value == 0) return a;
146 }
147 if (fb) {
148 if (fb->value == 1) return a;
149 if (fb->value == 0) return b;
150 }
151 });
152 return Expr();
153 }
154
155 template<>
156 inline Expr TryConstFold<ir::Div>(Expr a, Expr b) {
157 TVM_ARITH_CONST_PROPAGATION({
158 const Type& rtype = a.type();
159 if (pa && pb) {
160 // due to division and mod can have different modes
161 // NOTE: this will assumes truc div.
162 CHECK_NE(pb->value, 0) << "Divide by zero";
163 return IntImm::make(rtype, pa->value / pb->value);
164 }
165 if (pa) {
166 if (pa->value == 0) return a;
167 }
168 if (pb) {
169 if (pb->value == 1) return a;
170 CHECK_NE(pb->value, 0) << "Divide by zero";
171 }
172 if (fa && fb && fb->value != 0) {
173 return FloatImm::make(rtype, fa->value / fb->value);
174 }
175 if (fa && fa->value == 0) return a;
176 if (fb) {
177 if (fb->value == 1) return a;
178 CHECK_NE(fb->value, 0) << "Divide by zero";
179 }
180 });
181 return Expr();
182 }
183
184 template<>
185 inline Expr TryConstFold<ir::Mod>(Expr a, Expr b) {
186 TVM_INDEX_CONST_PROPAGATION({
187 const Type& rtype = a.type();
188 if (pa && pb) {
189 return IntImm::make(rtype, pa->value % pb->value);
190 }
191 if (pa) {
192 if (pa->value == 0) return a;
193 }
194 if (pb) {
195 if (pb->value == 1) return make_zero(rtype);
196 CHECK_NE(pb->value, 0) << "Divide by zero";
197 }
198 });
199 return Expr();
200 }
201
202 template<>
203 inline Expr TryConstFold<ir::FloorDiv>(Expr a, Expr b) {
204 TVM_ARITH_CONST_PROPAGATION({
205 const Type& rtype = a.type();
206 if (pa && pb) {
207 CHECK_NE(pb->value, 0) << "Divide by zero";
208 return IntImm::make(rtype, arith::floordiv(pa->value, pb->value));
209 }
210 if (pa) {
211 if (pa->value == 0) return a;
212 }
213 if (pb) {
214 if (pb->value == 1) return a;
215 CHECK_NE(pb->value, 0) << "Divide by zero";
216 }
217 if (fa && fb && fb->value != 0) {
218 return FloatImm::make(rtype, std::floor(fa->value / fb->value));
219 }
220 if (fa && fa->value == 0) return a;
221 if (fb) {
222 if (fb->value == 1) return a;
223 CHECK_NE(fb->value, 0) << "Divide by zero";
224 }
225 });
226 return Expr();
227 }
228
229 template<>
230 inline Expr TryConstFold<ir::FloorMod>(Expr a, Expr b) {
231 TVM_INDEX_CONST_PROPAGATION({
232 const Type& rtype = a.type();
233 if (pa && pb) {
234 return IntImm::make(rtype, arith::floormod(pa->value, pb->value));
235 }
236 if (pa) {
237 if (pa->value == 0) return a;
238 }
239 if (pb) {
240 if (pb->value == 1) return make_zero(rtype);
241 CHECK_NE(pb->value, 0) << "Divide by zero";
242 }
243 });
244 return Expr();
245 }
246
247 template<>
248 inline Expr TryConstFold<ir::Min>(Expr a, Expr b) {
249 TVM_ARITH_CONST_PROPAGATION({
250 const Type& rtype = a.type();
251 if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value));
252 if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value));
253 });
254 if (a.same_as(b)) return a;
255 return Expr();
256 }
257
258 template<>
259 inline Expr TryConstFold<ir::Max>(Expr a, Expr b) {
260 TVM_ARITH_CONST_PROPAGATION({
261 const Type& rtype = a.type();
262 if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value));
263 if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value));
264 });
265 if (a.same_as(b)) return a;
266 return Expr();
267 }
268
269 template<>
270 inline Expr TryConstFold<ir::GT>(Expr a, Expr b) {
271 TVM_ARITH_CONST_PROPAGATION({
272 if (pa && pb) return UIntImm::make(UInt(1), pa->value > pb->value);
273 if (fa && fb) return UIntImm::make(UInt(1), fa->value > fb->value);
274 });
275 return Expr();
276 }
277
278 template<>
279 inline Expr TryConstFold<ir::GE>(Expr a, Expr b) {
280 TVM_ARITH_CONST_PROPAGATION({
281 if (pa && pb) return UIntImm::make(UInt(1), pa->value >= pb->value);
282 if (fa && fb) return UIntImm::make(UInt(1), fa->value >= fb->value);
283 });
284 return Expr();
285 }
286
287 template<>
288 inline Expr TryConstFold<ir::LT>(Expr a, Expr b) {
289 TVM_ARITH_CONST_PROPAGATION({
290 if (pa && pb) return UIntImm::make(UInt(1), pa->value < pb->value);
291 if (fa && fb) return UIntImm::make(UInt(1), fa->value < fb->value);
292 });
293 return Expr();
294 }
295
296 template<>
297 inline Expr TryConstFold<ir::LE>(Expr a, Expr b) {
298 TVM_ARITH_CONST_PROPAGATION({
299 if (pa && pb) return UIntImm::make(UInt(1), pa->value <= pb->value);
300 if (fa && fb) return UIntImm::make(UInt(1), fa->value <= fb->value);
301 });
302 return Expr();
303 }
304
305 template<>
306 inline Expr TryConstFold<ir::EQ>(Expr a, Expr b) {
307 TVM_ARITH_CONST_PROPAGATION({
308 if (pa && pb) return UIntImm::make(UInt(1), pa->value == pb->value);
309 if (fa && fb) return UIntImm::make(UInt(1), fa->value == fb->value);
310 });
311 return Expr();
312 }
313
314 template<>
315 inline Expr TryConstFold<ir::NE>(Expr a, Expr b) {
316 TVM_ARITH_CONST_PROPAGATION({
317 if (pa && pb) return UIntImm::make(UInt(1), pa->value != pb->value);
318 if (fa && fb) return UIntImm::make(UInt(1), fa->value != fb->value);
319 });
320 return Expr();
321 }
322
323 template<>
324 inline Expr TryConstFold<ir::And>(Expr a, Expr b) {
325 using ir::UIntImm;
326 const UIntImm* pa = a.as<UIntImm>();
327 const UIntImm* pb = b.as<UIntImm>();
328 if (pa && pa->value) return b;
329 if (pa && !pa->value) return a;
330 if (pb && pb->value) return a;
331 if (pb && !pb->value) return b;
332 return Expr();
333 }
334
335 template<>
336 inline Expr TryConstFold<ir::Or>(Expr a, Expr b) {
337 using ir::UIntImm;
338 const UIntImm* pa = a.as<UIntImm>();
339 const UIntImm* pb = b.as<UIntImm>();
340 if (pa && pa->value) return a;
341 if (pa && !pa->value) return b;
342 if (pb && pb->value) return b;
343 if (pb && !pb->value) return a;
344 return Expr();
345 }
346
347 template<>
348 inline Expr TryConstFold<ir::Not>(Expr a) {
349 using ir::UIntImm;
350 const UIntImm* pa = a.as<UIntImm>();
351 if (pa) {
352 return UIntImm::make(UInt(1), !(pa->value));
353 }
354 return Expr();
355 }
356
/*!
 * \brief Helper struct holding the symbolic value limits.
 *
 * Both members are only declared here; their definitions are not part
 * of this header.
 */
struct SymbolicLimits {
  /*! \brief positive infinity */
  static Expr pos_inf_;
  /*! \brief negative infinity */
  static Expr neg_inf_;
};
364
365 /*!
366 * \brief Opaque expression representing positive infinity.
367 *
368 * It can can only be used as parameter of by min/max
369 * for integer analysis and cannot be used in normal expressions.
370 *
371 * \return positive infinity.
372 */
pos_inf()373 inline Expr pos_inf() {
374 return SymbolicLimits::pos_inf_;
375 }
376
377 /*!
378 * \brief Check if value is positive infinity.
379 * \param value The value to be checked.
380 *
381 * \return The check result.
382 */
is_pos_inf(const Expr & value)383 inline bool is_pos_inf(const Expr& value) {
384 return value.same_as(SymbolicLimits::pos_inf_);
385 }
386
387 /*!
388 * \brief Opaque expression representing negative infinity.
389 *
390 * It can can only be used as parameter of by min/max
391 * for integer analysis and cannot be used in normal expressions.
392 *
393 * \return negative infinity.
394 */
neg_inf()395 inline Expr neg_inf() {
396 return SymbolicLimits::neg_inf_;
397 }
398
399 /*!
400 * \brief Check if value is negative infinity.
401 * \param value The value to be checked.
402 *
403 * \return The check result.
404 */
is_neg_inf(const Expr & value)405 inline bool is_neg_inf(const Expr& value) {
406 return value.same_as(SymbolicLimits::neg_inf_);
407 }
408
409 } // namespace arith
410 } // namespace tvm
411 #endif // TVM_ARITHMETIC_CONST_FOLD_H_
412