1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file const_fold.h
22  * \brief Centralized location for constant folding.
23  */
24 #ifndef TVM_ARITHMETIC_CONST_FOLD_H_
25 #define TVM_ARITHMETIC_CONST_FOLD_H_
26 
27 #include <tvm/ir.h>
28 #include <tvm/ir_mutator.h>
29 #include <tvm/expr_operator.h>
30 #include <algorithm>
31 #include <cmath>
32 #include "int_operator.h"
33 
34 namespace tvm {
35 namespace arith {
36 
37 /*!
38  * \brief Try to run binary compute with constant folding.
39  *
40  * \param a The left operand.
41  * \param b The right operand.
42  * \tparam Op The operator type.
43  *
44  * \note a and b Must already matched data types with each other.
45  * \return nullptr if constant fold fails, otherwise return folded result.
46  */
47 template<typename Op>
TryConstFold(Expr a,Expr b)48 inline Expr TryConstFold(Expr a, Expr b) {
49   return Expr();
50 }
51 
52 /*!
53  * \brief Try to run unary compute with constant folding.
54  *
55  * \param a The left operand.
56  * \tparam Op The operator type.
57  *
58  * \note a and b Must already matched data types with each other.
59  * \return nullptr if constant fold fails, otherwise return folded result.
60  */
61 template<typename Op>
62 inline Expr TryConstFold(Expr a);
63 
64 /*!
65  * \brief Check whether type is used to represent index.
66  *
67  * Index types are frequently used in shape computation
68  * and need to be aggressively constant-folded.
69  *
70  * \param type The type to represent index.
71  * \return the checked result.
72  */
IsIndexType(const Type & type)73 inline bool IsIndexType(const Type& type) {
74   return type.is_int() && type.lanes() == 1 &&
75       (type.bits() == 32 || type.bits() == 64);
76 }
77 
78 
/*!
 * \brief Expand BODY with constant views of the operands `a` and `b`.
 *
 *  Declares pa/pb (IntImm views of a/b) and fa/fb (FloatImm views of a/b).
 *  Each pointer is null when its operand is not that kind of immediate.
 *  IntImm, UIntImm and FloatImm are brought into scope so that BODY can
 *  build folded results. Names `a` and `b` must be in scope at the
 *  expansion site.
 */
#define TVM_ARITH_CONST_PROPAGATION(BODY)                               \
  using ir::FloatImm;                                                   \
  using ir::IntImm;                                                     \
  using ir::UIntImm;                                                    \
  const IntImm* pa = a.as<IntImm>();                                    \
  const FloatImm* fa = a.as<FloatImm>();                                \
  const IntImm* pb = b.as<IntImm>();                                    \
  const FloatImm* fb = b.as<FloatImm>();                                \
  BODY;
88 
89 
/*!
 * \brief Expand BODY with integer views of `a` and `b`, but only when
 *  both operands have an index type (see IsIndexType).
 *
 *  Declares pa/pb (IntImm views of a/b, null when the operand is not an
 *  IntImm) and ta/tb (the operand types). It also brings IntImm and
 *  UIntImm into scope. Names `a` and `b` must be in scope at the
 *  expansion site.
 *
 *  The expansion deliberately ends without a trailing line continuation.
 *  A continuation there would splice the next source line into the macro.
 */
#define TVM_INDEX_CONST_PROPAGATION(BODY)                               \
  using ir::IntImm;                                                     \
  using ir::UIntImm;                                                    \
  const IntImm* pa = a.as<IntImm>();                                    \
  const IntImm* pb = b.as<IntImm>();                                    \
  const Type& ta = a.type();                                            \
  const Type& tb = b.type();                                            \
  if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) {               \
    BODY;                                                               \
  }
100 
101 
// Specializations of TryConstFold, one per supported operator.
103 template<>
104 inline Expr TryConstFold<ir::Add>(Expr a, Expr b) {
105   TVM_ARITH_CONST_PROPAGATION({
106       const Type& rtype = a.type();
107       if (pa && pb) return IntImm::make(rtype, pa->value + pb->value);
108       if (pa && pa->value == 0) return b;
109       if (pb && pb->value == 0) return a;
110       if (fa && fb) return FloatImm::make(rtype, fa->value + fb->value);
111       if (fa && fa->value == 0) return b;
112       if (fb && fb->value == 0) return a;
113     });
114   return Expr();
115 }
116 
117 template<>
118 inline Expr TryConstFold<ir::Sub>(Expr a, Expr b) {
119   TVM_ARITH_CONST_PROPAGATION({
120       const Type& rtype = a.type();
121       if (pa && pb) return IntImm::make(rtype, pa->value - pb->value);
122       if (pb && pb->value == 0) return a;
123       if (fa && fb) return FloatImm::make(rtype, fa->value - fb->value);
124       if (fb && fb->value == 0) return a;
125     });
126   return Expr();
127 }
128 
129 template<>
130 inline Expr TryConstFold<ir::Mul>(Expr a, Expr b) {
131   TVM_ARITH_CONST_PROPAGATION({
132       const Type& rtype = a.type();
133       if (pa && pb) return IntImm::make(rtype, pa->value * pb->value);
134       if (pa) {
135         if (pa->value == 1) return b;
136         if (pa->value == 0) return a;
137       }
138       if (pb) {
139         if (pb->value == 1) return a;
140         if (pb->value == 0) return b;
141       }
142       if (fa && fb) return FloatImm::make(rtype, fa->value * fb->value);
143       if (fa) {
144         if (fa->value == 1) return b;
145         if (fa->value == 0) return a;
146       }
147       if (fb) {
148         if (fb->value == 1) return a;
149         if (fb->value == 0) return b;
150       }
151     });
152   return Expr();
153 }
154 
155 template<>
156 inline Expr TryConstFold<ir::Div>(Expr a, Expr b) {
157   TVM_ARITH_CONST_PROPAGATION({
158       const Type& rtype = a.type();
159       if (pa && pb) {
160         // due to division and mod can have different modes
161         // NOTE: this will assumes truc div.
162         CHECK_NE(pb->value, 0) << "Divide by zero";
163         return IntImm::make(rtype, pa->value / pb->value);
164       }
165       if (pa) {
166         if (pa->value == 0) return a;
167       }
168       if (pb) {
169         if (pb->value == 1) return a;
170         CHECK_NE(pb->value, 0) << "Divide by zero";
171       }
172       if (fa && fb && fb->value != 0) {
173         return FloatImm::make(rtype, fa->value / fb->value);
174       }
175       if (fa && fa->value == 0) return a;
176       if (fb) {
177         if (fb->value == 1) return a;
178         CHECK_NE(fb->value, 0) << "Divide by zero";
179       }
180     });
181   return Expr();
182 }
183 
184 template<>
185 inline Expr TryConstFold<ir::Mod>(Expr a, Expr b) {
186   TVM_INDEX_CONST_PROPAGATION({
187       const Type& rtype = a.type();
188       if (pa && pb) {
189         return IntImm::make(rtype, pa->value % pb->value);
190       }
191       if (pa) {
192         if (pa->value == 0) return a;
193       }
194       if (pb) {
195         if (pb->value == 1) return make_zero(rtype);
196         CHECK_NE(pb->value, 0) << "Divide by zero";
197       }
198     });
199   return Expr();
200 }
201 
202 template<>
203 inline Expr TryConstFold<ir::FloorDiv>(Expr a, Expr b) {
204   TVM_ARITH_CONST_PROPAGATION({
205       const Type& rtype = a.type();
206       if (pa && pb) {
207         CHECK_NE(pb->value, 0) << "Divide by zero";
208         return IntImm::make(rtype, arith::floordiv(pa->value, pb->value));
209       }
210       if (pa) {
211         if (pa->value == 0) return a;
212       }
213       if (pb) {
214         if (pb->value == 1) return a;
215         CHECK_NE(pb->value, 0) << "Divide by zero";
216       }
217       if (fa && fb && fb->value != 0) {
218         return FloatImm::make(rtype, std::floor(fa->value / fb->value));
219       }
220       if (fa && fa->value == 0) return a;
221       if (fb) {
222         if (fb->value == 1) return a;
223         CHECK_NE(fb->value, 0) << "Divide by zero";
224       }
225     });
226   return Expr();
227 }
228 
229 template<>
230 inline Expr TryConstFold<ir::FloorMod>(Expr a, Expr b) {
231   TVM_INDEX_CONST_PROPAGATION({
232       const Type& rtype = a.type();
233       if (pa && pb) {
234         return IntImm::make(rtype, arith::floormod(pa->value, pb->value));
235       }
236       if (pa) {
237         if (pa->value == 0) return a;
238       }
239       if (pb) {
240         if (pb->value == 1) return make_zero(rtype);
241         CHECK_NE(pb->value, 0) << "Divide by zero";
242       }
243     });
244   return Expr();
245 }
246 
247 template<>
248 inline Expr TryConstFold<ir::Min>(Expr a, Expr b) {
249   TVM_ARITH_CONST_PROPAGATION({
250       const Type& rtype = a.type();
251       if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value));
252       if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value));
253     });
254   if (a.same_as(b)) return a;
255   return Expr();
256 }
257 
258 template<>
259 inline Expr TryConstFold<ir::Max>(Expr a, Expr b) {
260   TVM_ARITH_CONST_PROPAGATION({
261       const Type& rtype = a.type();
262       if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value));
263       if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value));
264     });
265   if (a.same_as(b)) return a;
266   return Expr();
267 }
268 
269 template<>
270 inline Expr TryConstFold<ir::GT>(Expr a, Expr b) {
271   TVM_ARITH_CONST_PROPAGATION({
272       if (pa && pb) return UIntImm::make(UInt(1), pa->value > pb->value);
273       if (fa && fb) return UIntImm::make(UInt(1), fa->value > fb->value);
274     });
275   return Expr();
276 }
277 
278 template<>
279 inline Expr TryConstFold<ir::GE>(Expr a, Expr b) {
280   TVM_ARITH_CONST_PROPAGATION({
281       if (pa && pb) return UIntImm::make(UInt(1), pa->value >= pb->value);
282       if (fa && fb) return UIntImm::make(UInt(1), fa->value >= fb->value);
283     });
284   return Expr();
285 }
286 
287 template<>
288 inline Expr TryConstFold<ir::LT>(Expr a, Expr b) {
289   TVM_ARITH_CONST_PROPAGATION({
290       if (pa && pb) return UIntImm::make(UInt(1), pa->value < pb->value);
291       if (fa && fb) return UIntImm::make(UInt(1), fa->value < fb->value);
292     });
293   return Expr();
294 }
295 
296 template<>
297 inline Expr TryConstFold<ir::LE>(Expr a, Expr b) {
298   TVM_ARITH_CONST_PROPAGATION({
299       if (pa && pb) return UIntImm::make(UInt(1), pa->value <= pb->value);
300       if (fa && fb) return UIntImm::make(UInt(1), fa->value <= fb->value);
301     });
302   return Expr();
303 }
304 
305 template<>
306 inline Expr TryConstFold<ir::EQ>(Expr a, Expr b) {
307   TVM_ARITH_CONST_PROPAGATION({
308       if (pa && pb) return UIntImm::make(UInt(1), pa->value == pb->value);
309       if (fa && fb) return UIntImm::make(UInt(1), fa->value == fb->value);
310     });
311   return Expr();
312 }
313 
314 template<>
315 inline Expr TryConstFold<ir::NE>(Expr a, Expr b) {
316   TVM_ARITH_CONST_PROPAGATION({
317       if (pa && pb) return UIntImm::make(UInt(1), pa->value != pb->value);
318       if (fa && fb) return UIntImm::make(UInt(1), fa->value != fb->value);
319     });
320   return Expr();
321 }
322 
323 template<>
324 inline Expr TryConstFold<ir::And>(Expr a, Expr b) {
325   using ir::UIntImm;
326   const UIntImm* pa = a.as<UIntImm>();
327   const UIntImm* pb = b.as<UIntImm>();
328   if (pa && pa->value) return b;
329   if (pa && !pa->value) return a;
330   if (pb && pb->value) return a;
331   if (pb && !pb->value) return b;
332   return Expr();
333 }
334 
335 template<>
336 inline Expr TryConstFold<ir::Or>(Expr a, Expr b) {
337   using ir::UIntImm;
338   const UIntImm* pa = a.as<UIntImm>();
339   const UIntImm* pb = b.as<UIntImm>();
340   if (pa && pa->value) return a;
341   if (pa && !pa->value) return b;
342   if (pb && pb->value) return b;
343   if (pb && !pb->value) return a;
344   return Expr();
345 }
346 
347 template<>
348 inline Expr TryConstFold<ir::Not>(Expr a) {
349   using ir::UIntImm;
350   const UIntImm* pa = a.as<UIntImm>();
351   if (pa) {
352     return UIntImm::make(UInt(1), !(pa->value));
353   }
354   return Expr();
355 }
356 
357 /*! \brief Helper namespace for symbolic value limits */
358 struct SymbolicLimits {
359   /*! \brief positive infinity */
360   static Expr pos_inf_;
361   /*! \brief negative infinity */
362   static Expr neg_inf_;
363 };
364 
365 /*!
366  * \brief Opaque expression representing positive infinity.
367  *
368  *  It can can only be used as parameter of by min/max
369  *  for integer analysis and cannot be used in normal expressions.
370  *
371  * \return positive infinity.
372  */
pos_inf()373 inline Expr pos_inf() {
374   return SymbolicLimits::pos_inf_;
375 }
376 
377 /*!
378  * \brief Check if value is positive infinity.
379  * \param value The value to be checked.
380  *
381  * \return The check result.
382  */
is_pos_inf(const Expr & value)383 inline bool is_pos_inf(const Expr& value) {
384   return value.same_as(SymbolicLimits::pos_inf_);
385 }
386 
387 /*!
388  * \brief Opaque expression representing negative infinity.
389  *
390  *  It can can only be used as parameter of by min/max
391  *  for integer analysis and cannot be used in normal expressions.
392  *
393  * \return negative infinity.
394  */
neg_inf()395 inline Expr neg_inf() {
396   return SymbolicLimits::neg_inf_;
397 }
398 
399 /*!
400  * \brief Check if value is negative infinity.
401  * \param value The value to be checked.
402  *
403  * \return The check result.
404  */
is_neg_inf(const Expr & value)405 inline bool is_neg_inf(const Expr& value) {
406   return value.same_as(SymbolicLimits::neg_inf_);
407 }
408 
409 }  // namespace arith
410 }  // namespace tvm
411 #endif  // TVM_ARITHMETIC_CONST_FOLD_H_
412