1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file tvm/arith/solve_linear_inequality.cc
22  * \brief Solve linear inequalities.
23  */
24 #include <tvm/arith/analyzer.h>
25 #include <tvm/arith/int_solver.h>
26 #include <tvm/arith/pattern.h>
27 #include <tvm/runtime/data_type.h>
28 #include <tvm/runtime/registry.h>
29 #include <tvm/tir/analysis.h>
30 #include <tvm/tir/expr.h>
31 #include <tvm/tir/op.h>
32 #include <tvm/tir/stmt_functor.h>
33 
34 #include "int_operator.h"
35 
36 namespace tvm {
37 namespace arith {
38 
39 using namespace tvm::runtime;
40 using namespace tvm::tir;
41 
// Generates a `VisitExpr_` override for node type `OP` that counts the node as one symbol
// and does not visit anything else (used below for leaf nodes: variables and immediates).
// Requires a `num_symbols_` member in the class where it is expanded.
#define PLUS_ONE(OP) \
  void VisitExpr_(const OP* op) final { num_symbols_++; }

// Generates a `VisitExpr_` override for binary node type `OP` that counts the node as one
// symbol and then visits both operands `op->a` and `op->b`.
// Requires a `num_symbols_` member in the class where it is expanded.
#define PLUS_ONE_BINARY(OP)             \
  void VisitExpr_(const OP* op) final { \
    num_symbols_++;                     \
    VisitExpr(op->a);                   \
    VisitExpr(op->b);                   \
  }
51 
52 /*!
53  * \brief Calculate the expresion complexity based on number of symbols it contains.
54  */
55 class ExprComplexity : public ExprVisitor {
56  public:
Eval(const PrimExpr & expr)57   size_t Eval(const PrimExpr& expr) {
58     VisitExpr(expr);
59     return num_symbols_;
60   }
61 
62   PLUS_ONE_BINARY(AddNode)
PLUS_ONE_BINARY(SubNode)63   PLUS_ONE_BINARY(SubNode)
64   PLUS_ONE_BINARY(MulNode)
65   PLUS_ONE_BINARY(DivNode)
66   PLUS_ONE_BINARY(ModNode)
67   PLUS_ONE_BINARY(FloorDivNode)
68   PLUS_ONE_BINARY(FloorModNode)
69   PLUS_ONE_BINARY(MinNode)
70   PLUS_ONE_BINARY(MaxNode)
71   PLUS_ONE_BINARY(EQNode)
72   PLUS_ONE_BINARY(NENode)
73   PLUS_ONE_BINARY(LTNode)
74   PLUS_ONE_BINARY(LENode)
75   PLUS_ONE_BINARY(GTNode)
76   PLUS_ONE_BINARY(GENode)
77   PLUS_ONE_BINARY(AndNode)
78   PLUS_ONE_BINARY(OrNode)
79   PLUS_ONE(VarNode)
80   PLUS_ONE(FloatImmNode)
81   PLUS_ONE(IntImmNode)
82   void VisitExpr_(const NotNode* op) final {
83     num_symbols_++;
84     VisitExpr(op->a);
85   }
86 
87  private:
88   size_t num_symbols_{0};
89 };
90 
91 struct ExprLess {
operator ()tvm::arith::ExprLess92   bool operator()(const PrimExpr& l, const PrimExpr& r) const {
93     return ExprComplexity().Eval(l) < ExprComplexity().Eval(r);
94   }
95 };
96 
DebugPrint(const std::unordered_set<PrimExpr,StructuralHash,StructuralEqual> & current_ineq_set,const std::unordered_set<PrimExpr,StructuralHash,StructuralEqual> & next_ineq_set,const std::vector<PrimExpr> & rest,const std::vector<std::pair<int64_t,PrimExpr>> & coef_pos,const std::vector<std::pair<int64_t,PrimExpr>> & coef_neg)97 void DebugPrint(
98     const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& current_ineq_set,
99     const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& next_ineq_set,
100     const std::vector<PrimExpr>& rest, const std::vector<std::pair<int64_t, PrimExpr>>& coef_pos,
101     const std::vector<std::pair<int64_t, PrimExpr>>& coef_neg) {
102   std::cout << "Current ineq set:\n[";
103   for (auto& ineq : current_ineq_set) {
104     std::cout << ineq << ", ";
105   }
106   std::cout << "]\n";
107 
108   std::cout << "Next ineq set:\n[";
109   for (auto& ineq : next_ineq_set) {
110     std::cout << ineq << ", ";
111   }
112   std::cout << "]\n";
113 
114   std::cout << "coef_pos:\n[";
115   for (auto& coef : coef_pos) {
116     std::cout << "(" << coef.first << ", " << coef.second << "), ";
117   }
118   std::cout << "]\n";
119 
120   std::cout << "coef_neg:\n[";
121   for (auto& coef : coef_neg) {
122     std::cout << "(" << coef.first << ", " << coef.second << "), ";
123   }
124   std::cout << "]\n";
125 }
126 
127 /*!
128  * \brief normalize to the form `expr <= 0`
129  */
130 class NormalizeComparisons : public ExprMutator {
131  public:
VisitExpr_(const EQNode * op)132   PrimExpr VisitExpr_(const EQNode* op) override { return Make<EQ>(op->a, op->b); }
VisitExpr_(const NENode * op)133   PrimExpr VisitExpr_(const NENode* op) override { return Make<NE>(op->a, op->b); }
VisitExpr_(const LTNode * op)134   PrimExpr VisitExpr_(const LTNode* op) override { return Make<LT>(op->a, op->b); }
VisitExpr_(const LENode * op)135   PrimExpr VisitExpr_(const LENode* op) override { return Make<LE>(op->a, op->b); }
VisitExpr_(const GTNode * op)136   PrimExpr VisitExpr_(const GTNode* op) override { return Make<LT>(op->b, op->a); }
VisitExpr_(const GENode * op)137   PrimExpr VisitExpr_(const GENode* op) override { return Make<LE>(op->b, op->a); }
138 
139  private:
140   template <class T>
Make(const PrimExpr & a,const PrimExpr & b)141   PrimExpr Make(const PrimExpr& a, const PrimExpr& b) {
142     // rewrite LT to LE for ints
143     if (std::is_same<T, LT>::value && (a.dtype().is_int() || a.dtype().is_uint())) {
144       return LE(analyzer_.Simplify(a - b + 1), make_zero(a.dtype()));
145     }
146     return T(analyzer_.Simplify(a - b), make_zero(a.dtype()));
147   }
148   arith::Analyzer analyzer_;
149 };
150 
AddInequality(std::unordered_set<PrimExpr,StructuralHash,StructuralEqual> * inequality_set,const PrimExpr & new_ineq,Analyzer * analyzer)151 void AddInequality(std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* inequality_set,
152                    const PrimExpr& new_ineq, Analyzer* analyzer) {
153   if (analyzer->CanProve(new_ineq) || inequality_set->find(new_ineq) != inequality_set->end()) {
154     // redundant: follows from the vranges
155     // or has already been added
156     return;
157   }
158   if (const LENode* new_le = new_ineq.as<LENode>()) {
159     for (auto iter = inequality_set->begin(); iter != inequality_set->end();) {
160       const LENode* le = iter->as<LENode>();
161       if (le && analyzer->CanProve(new_le->a - le->a <= 0)) {
162         return;
163       } else if (le && analyzer->CanProve(le->a - new_le->a <= 0)) {
164         iter = inequality_set->erase(iter);
165       } else {
166         ++iter;
167       }
168     }
169   }
170 
171   inequality_set->insert(new_ineq);
172 }
173 
ClassifyByPolarity(const Var & var,const std::unordered_set<PrimExpr,StructuralHash,StructuralEqual> & current_ineq_set,std::unordered_set<PrimExpr,StructuralHash,StructuralEqual> * next_ineq_set,std::vector<PrimExpr> * rest,std::vector<std::pair<int64_t,PrimExpr>> * coef_pos,std::vector<std::pair<int64_t,PrimExpr>> * coef_neg,Analyzer * analyzer)174 void ClassifyByPolarity(
175     const Var& var,
176     const std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>& current_ineq_set,
177     std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* next_ineq_set,
178     std::vector<PrimExpr>* rest, std::vector<std::pair<int64_t, PrimExpr>>* coef_pos,
179     std::vector<std::pair<int64_t, PrimExpr>>* coef_neg, Analyzer* analyzer) {
180   // Take formulas from current_ineq_set and classify them according to polarity wrt var
181   // and store to coef_pos and coef_neg respectively.
182   for (const PrimExpr& ineq : current_ineq_set) {
183     if (const LENode* le = ineq.as<LENode>()) {
184       Array<PrimExpr> coef = arith::DetectLinearEquation(le->a, {var});
185       if (!coef.empty() && is_const_int(coef[0])) {
186         int64_t coef0 = *as_const_int(coef[0]);
187         if (coef0 == 0) {
188           // zero polarity, straight to next_ineq_set
189           AddInequality(next_ineq_set, ineq, analyzer);
190         } else if (coef0 > 0) {
191           coef_pos->push_back({coef0, coef[1]});
192         } else if (coef0 < 0) {
193           coef_neg->push_back({coef0, coef[1]});
194         }
195         continue;
196       }
197     } else if (const EQNode* eq = ineq.as<EQNode>()) {
198       Array<PrimExpr> coef = arith::DetectLinearEquation(eq->a, {var});
199       if (!coef.empty() && is_const_int(coef[0])) {
200         int64_t coef0 = *as_const_int(coef[0]);
201         if (coef0 == 0) {
202           // zero polarity, straight to next_ineq_set
203           AddInequality(next_ineq_set, ineq, analyzer);
204         } else if (coef0 > 0) {
205           // Equalities may be considered as pairs of two inequalities
206           coef_pos->push_back({coef0, coef[1]});
207           coef_neg->push_back({-coef0, -coef[1]});
208         } else if (coef0 < 0) {
209           coef_pos->push_back({-coef0, -coef[1]});
210           coef_neg->push_back({coef0, coef[1]});
211         }
212         continue;
213       }
214     }
215 
216     // if nothing worked, put it in rest
217     rest->push_back(ineq);
218   }
219 }
220 
MoveEquality(std::unordered_set<PrimExpr,StructuralHash,StructuralEqual> * upper_bounds,std::unordered_set<PrimExpr,StructuralHash,StructuralEqual> * lower_bounds,std::unordered_set<PrimExpr,StructuralHash,StructuralEqual> * equalities)221 void MoveEquality(std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* upper_bounds,
222                   std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* lower_bounds,
223                   std::unordered_set<PrimExpr, StructuralHash, StructuralEqual>* equalities) {
224   // those exist in both upper & lower bounds will be moved to equalities
225   for (auto ub = upper_bounds->begin(); ub != upper_bounds->end();) {
226     auto lb = lower_bounds->find(*ub);
227     if (lb != lower_bounds->end()) {
228       equalities->insert(*lb);
229       lower_bounds->erase(lb);
230       ub = upper_bounds->erase(ub);
231     } else {
232       ++ub;
233     }
234   }
235 }
236 
SolveLinearInequalities(const IntConstraints & system_to_solve)237 PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_to_solve) {
238   arith::Analyzer analyzer;
239   analyzer.Bind(system_to_solve->ranges);
240 
241   // The algorithm consists in doing the following things for each variable v
242   // - Take formulas from `current_ineq_set_to_solve` and
243   //   classify them according to polarity wrt v.
244   // - Combine each formula of positive polarity (wrt v)
245   //   with each formula of negative polarity.
246   // - Put the resulting combinations into `next_ineq_set_to_solve`
247   //   along with unclassifiable formulas.
248   // - Replace `current_ineq_set_to_solve` with `next_ineq_set_to_solve`
249   //   and move to the next variable.
250 
251   // normalized inequality
252   std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> current_ineq_set_to_solve;
253   std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> next_ineq_set_to_solve;
254   // A vector of pairs (c, e), c > 0, representing formulas of the form c*v + e <= 0
255   std::vector<std::pair<int64_t, PrimExpr>> coef_pos;
256   // A vector of pairs (c, e), c < 0, representing formulas of the form c*v + e <= 0
257   std::vector<std::pair<int64_t, PrimExpr>> coef_neg;
258 
259   // formulas we don't know what to do with
260   std::vector<PrimExpr> rest;
261 
262   // Simplify each inequality into the form `expr <= 0` and add to current formulas
263   for (const PrimExpr& ineq : system_to_solve->relations) {
264     AddInequality(&current_ineq_set_to_solve,
265                   NormalizeComparisons()(analyzer.Simplify(ineq, kSimplifyRewriteCanonicalRewrite)),
266                   &analyzer);
267   }
268 
269   Map<Var, IntGroupBounds> res_bounds;
270   for (const Var& v : system_to_solve->variables) {
271     CHECK(!res_bounds.count(v))
272         << "Variable " << v
273         << " appears more than one time in the `variables` which might be a bug";
274 
275     next_ineq_set_to_solve.clear();
276     coef_pos.clear();
277     coef_neg.clear();
278 
279     // Add bounds from vranges
280     if (system_to_solve->ranges.count(v)) {
281       const Range& range = system_to_solve->ranges[v];
282       PrimExpr range_lbound = analyzer.Simplify(range->min, kSimplifyRewriteCanonicalRewrite);
283       PrimExpr range_ubound =
284           analyzer.Simplify(range->min + range->extent - 1, kSimplifyRewriteCanonicalRewrite);
285       coef_neg.push_back({-1, range_lbound});
286       coef_pos.push_back({1, -range_ubound});
287     }
288 
289     ClassifyByPolarity(v, current_ineq_set_to_solve, &next_ineq_set_to_solve, &rest, &coef_pos,
290                        &coef_neg, &analyzer);
291 
292     // Combine each positive inequality with each negative one (by adding them together)
293     int64_t gcd_x, gcd_y;
294     for (const auto& pos : coef_pos) {
295       for (const auto& neg : coef_neg) {
296         auto first_gcd = ExtendedEuclidean(pos.first, -neg.first, &gcd_x, &gcd_y);
297         PrimExpr c_pos = make_const(v.dtype(), neg.first / first_gcd);
298         PrimExpr c_neg = make_const(v.dtype(), pos.first / first_gcd);
299         // eliminate the current variable
300         PrimExpr new_lhs = c_neg * neg.second - c_pos * pos.second;
301         PrimExpr new_ineq = LE(new_lhs, make_zero(pos.second.dtype()));
302         // we need rewrite_simplify -> canonical_simplify -> rewrite_simplify
303         // to help simplify things like (((y + 10) - (-1*(y - 20))) <= 0) => y - 5 <= 0
304         // with steps = 2 it's (y*2) - 10 <= 0
305         new_ineq =
306             NormalizeComparisons()(analyzer.Simplify(new_ineq, kSimplifyRewriteCanonicalRewrite));
307         AddInequality(&next_ineq_set_to_solve, new_ineq, &analyzer);
308       }
309     }
310 
311     // Now we have to generate resulting (in)equalities for the variable v
312 
313     // Find the common denominator in a sense
314     // We will generate formulas of the form coef_lcm*v <= bound
315     int64_t coef_lcm = 1;
316     for (const auto& pos : coef_pos) {
317       coef_lcm = LeastCommonMultiple(coef_lcm, pos.first);
318     }
319     for (const auto& neg : coef_neg) {
320       coef_lcm = LeastCommonMultiple(coef_lcm, -neg.first);
321     }
322 
323     // The resulting lower and upper bounds
324     std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> upper_bounds;
325     std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> lower_bounds;
326     upper_bounds.reserve(coef_pos.size());
327     lower_bounds.reserve(coef_neg.size());
328 
329     for (const auto& pos : coef_pos) {
330       PrimExpr bound = make_const(v.dtype(), -coef_lcm / pos.first) * pos.second;
331       bound = analyzer.Simplify(bound, kSimplifyRewriteCanonicalRewrite);
332       // Don't add if any of the existing bounds is better
333       if (std::any_of(upper_bounds.begin(), upper_bounds.end(),
334                       [&bound, &analyzer](const PrimExpr& o) {
335                         return analyzer.CanProve(o - bound <= 0);
336                       })) {
337         continue;
338       }
339       // Erase all worse bounds
340       for (auto iter = upper_bounds.begin(); iter != upper_bounds.end();) {
341         if (analyzer.CanProve(*iter - bound >= 0)) {
342           iter = upper_bounds.erase(iter);
343         } else {
344           ++iter;
345         }
346       }
347       // Add the upper bound
348       upper_bounds.insert(bound);
349     }
350     for (const auto& neg : coef_neg) {
351       PrimExpr bound = make_const(v.dtype(), -coef_lcm / neg.first) * neg.second;
352       bound = analyzer.Simplify(bound, kSimplifyRewriteCanonicalRewrite);
353       // Don't add if any of the existing bounds is better
354       if (std::any_of(lower_bounds.begin(), lower_bounds.end(),
355                       [&bound, &analyzer](const PrimExpr& o) {
356                         return analyzer.CanProve(o - bound >= 0);
357                       })) {
358         continue;
359       }
360       // Erase all worse bounds
361       for (auto iter = lower_bounds.begin(); iter != lower_bounds.end();) {
362         if (analyzer.CanProve(*iter - bound <= 0)) {
363           iter = lower_bounds.erase(iter);
364         } else {
365           ++iter;
366         }
367       }
368       // Add the lower bound
369       lower_bounds.insert(bound);
370     }
371 
372     std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> equal;
373     equal.reserve(std::min(upper_bounds.size(), lower_bounds.size()));
374     MoveEquality(&upper_bounds, &lower_bounds, &equal);
375     std::vector<PrimExpr> equal_list(equal.begin(), equal.end());
376     std::sort(equal_list.begin(), equal_list.end(), ExprLess());
377 
378     // Write it to the result.
379     IntGroupBounds bnds(make_const(v.dtype(), coef_lcm),
380                         Array<PrimExpr>(lower_bounds.begin(), lower_bounds.end()),
381                         Array<PrimExpr>(equal_list.begin(), equal_list.end()),
382                         Array<PrimExpr>(upper_bounds.begin(), upper_bounds.end()));
383     res_bounds.Set(v, bnds);
384 
385     std::swap(current_ineq_set_to_solve, next_ineq_set_to_solve);
386   }
387 
388   // Everything that is left goes to res.relations
389   Array<PrimExpr> other_conditions;
390   for (const PrimExpr& e : current_ineq_set_to_solve) {
391     PrimExpr e_simp = analyzer.Simplify(e, kSimplifyRewriteCanonicalRewrite);
392     if (is_const_int(e_simp, 0)) {
393       // contradiction detected
394       other_conditions = {const_false()};
395       break;
396     } else if (is_const_int(e_simp, 1)) {
397       continue;
398     } else {
399       other_conditions.push_back(e_simp);
400     }
401   }
402 
403   for (const PrimExpr& e : rest) {
404     other_conditions.push_back(e);
405   }
406 
407   return {res_bounds, other_conditions};
408 }
409 
410 #ifdef _MSC_VER
411 #pragma optimize("g", off)
412 #endif
SolveInequalitiesToRange(const IntConstraints & inequalities)413 IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) {
414   // Resulting ranges will contain ranges for the new variables and for the variables that are
415   // not in the inequalities->variables but are in inequalities->ranges
416   // It will be useful when solving Jacobian axes jac_xxx)
417   Map<Var, Range> res_ranges;
418   // we get a set of equality, lower, upper bound of each variable.
419   auto solved_system = SolveLinearInequalities(inequalities);
420 
421   Map<Var, IntGroupBounds> solved_bounds = solved_system.first;
422   Array<PrimExpr> solved_other_relations = solved_system.second;
423 
424   Array<PrimExpr> res_relations;
425 
426   // this keeps being updated during determining the range of each variable.
427   Map<Var, Range> vranges;
428   for (std::pair<Var, Range> vr : inequalities->ranges) {
429     vranges.Set(vr.first, vr.second);
430   }
431 
432   // We process variables in the reverse direction to start with the most independent one.
433   // This order is needed to compute new ranges.
434   for (auto it = inequalities->variables.rbegin(); it != inequalities->variables.rend(); ++it) {
435     arith::Analyzer analyzer;
436     analyzer.Bind(vranges);
437 
438     const Var& var = *it;
439     CHECK(solved_bounds.count(var));
440     auto bnd = solved_bounds[var];
441     if (is_one(bnd->coef) && !bnd->equal.empty()) {
442       // There is an equation of the form `v == expr`, so this variable can be completely removed.
443       // Note that we use the 0-th expression because they are ordered by complexity,
444       // so it must be the simplest one.
445       // The MSVC compiler optimization must be disabled for the expression `bnd->equal[0]` which
446       // triggers an internal compiler error.
447       Range best_range(bnd->equal[0],
448                        analyzer.Simplify(bnd->equal[0] + 1, kSimplifyRewriteCanonicalRewrite));
449       res_ranges.Set(var, best_range);
450       vranges.Set(var, best_range);
451     } else {
452       if (vranges.count(var) > 0) {
453         bnd = bnd + vranges[var];
454       }
455 
456       auto best_range = bnd.FindBestRange(vranges);
457 
458       if (best_range.defined()) {
459         if (analyzer.CanProveGreaterEqual(-best_range->extent, 0)) {
460           // range.extent <= 0 implies the input inequality system is unsolvable
461           return IntConstraints(/*variables=*/{}, /*ranges=*/{},
462                                 /*relations=*/{tir::make_zero(DataType::Bool())});
463         }
464         res_ranges.Set(var, best_range);
465         vranges.Set(var, best_range);
466       }
467     }
468   }
469 
470   // Add the original conditions to the resulting conditions
471   arith::Analyzer analyzer;
472   analyzer.Bind(vranges);
473   for (const PrimExpr& old_cond :
474        AsConditions(inequalities->variables, solved_bounds, solved_other_relations)) {
475     if (!analyzer.CanProve(old_cond)) {
476       // those not represented in vranges (res_ranges)
477       res_relations.push_back(old_cond);
478     }
479   }
480 
481   IntConstraints system(inequalities->variables, res_ranges, res_relations);
482 
483   return system;
484 }
485 #ifdef _MSC_VER
486 #pragma optimize("g", on)
487 #endif
488 
SolveInequalitiesDeskewRange(const IntConstraints & inequalities)489 IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequalities) {
490   // Resulting ranges will contain ranges for the new variables and for the variables that are
491   // not in the inequalities->variables but are in inequalities->ranges (jac_xxx)
492   Map<Var, Range> res_ranges;
493   // we get a set of equality, lower, upper bound of each variable.
494   auto solved_system = SolveLinearInequalities(inequalities);
495   Map<Var, IntGroupBounds> solved_bounds = solved_system.first;
496   Array<PrimExpr> solved_other_relations = solved_system.second;
497 
498   arith::Analyzer analyzer;
499 
500   Map<Var, PrimExpr> res_src_to_dst;
501   Map<Var, PrimExpr> res_dst_to_src;
502   Array<Var> res_variables;
503   Array<PrimExpr> res_relations;
504 
505   // this keeps being updated during determining the range of each variable.
506   Map<Var, Range> vranges;
507   for (std::pair<Var, Range> vr : inequalities->ranges) {
508     vranges.Set(vr.first, vr.second);
509   }
510   analyzer.Bind(vranges);
511 
512   // We process variables in the reverse direction to start with the most independent one.
513   // This order is needed to compute new ranges.
514   for (auto it = inequalities->variables.rbegin(); it != inequalities->variables.rend(); ++it) {
515     const Var& var = *it;
516     auto bnd = solved_bounds[var];
517     // Note that we replace old vars with new ones
518     bnd = bnd.Substitute(res_src_to_dst);
519 
520     if (is_one(bnd->coef) && !bnd->equal.empty()) {
521       // There is an equation of the form `v == expr`,
522       // so this variable can be completely removed.
523       // Note that we use the 0-th expression because they are ordered by complexity,
524       // so it must be the simplest one.
525       res_src_to_dst.Set(var, bnd->equal[0]);
526     } else {
527       if (vranges.count(var) > 0) {
528         bnd = bnd + vranges[var];
529       }
530 
531       auto best_range = bnd.FindBestRange(vranges);
532 
533       Var new_var = var.copy_with_suffix(".shifted");
534       if (!best_range.defined()) {
535         res_src_to_dst.Set(var, var);
536         res_dst_to_src.Set(var, var);
537         res_variables.push_back(var);
538       } else if (is_const_int(best_range->extent, 1)) {
539         // Don't create an itervar, just replace it everywhere with its min
540         res_src_to_dst.Set(var, best_range->min);
541       } else if (analyzer.CanProveGreaterEqual(-best_range->extent, 0)) {
542         // range.extent <= 0 implies the input inequality system is unsolvable
543         return IntConstraintsTransform(inequalities,
544                                        IntConstraints(
545                                            /*variables=*/{},
546                                            /*ranges=*/{},
547                                            /*relations=*/{tir::make_zero(DataType::Bool())}),
548                                        {}, {});
549       } else {
550         // created new_var starts from 0
551         res_src_to_dst.Set(var, new_var + best_range->min);
552         // Note that we are substituting old with new, so best_range contains new var,
553         // that is we have to substitute new with old in best_range here
554         res_dst_to_src.Set(new_var,
555                            analyzer.Simplify(var - Substitute(best_range->min, res_dst_to_src)));
556 
557         // Add the new var to the resulting axis
558         auto range = Range(make_zero(new_var.dtype()), best_range->extent);
559         res_variables.push_back(new_var);
560         res_ranges.Set(new_var, range);
561 
562         vranges.Set(new_var, range);
563         analyzer.Bind(new_var, range);
564       }
565     }
566   }
567 
568   // Add the original conditions (with variables substituted) to the resulting conditions
569   for (const PrimExpr& old_cond :
570        AsConditions(inequalities->variables, solved_bounds, solved_other_relations)) {
571     PrimExpr new_cond = analyzer.Simplify(Substitute(old_cond, res_src_to_dst));
572     if (!is_const_int(new_cond, 1)) {
573       // those not represented in vranges (res_ranges)
574       res_relations.push_back(new_cond);
575     }
576   }
577 
578   // Reverse the axis so that it matches the order of the original variables
579   res_variables = Array<Var>(res_variables.rbegin(), res_variables.rend());
580 
581   IntConstraints new_inequalities(res_variables, res_ranges, res_relations);
582   IntConstraintsTransform transform(inequalities, new_inequalities, res_src_to_dst, res_dst_to_src);
583 
584   return transform;
585 }
586 
587 TVM_REGISTER_GLOBAL("arith.SolveInequalitiesAsCondition")
__anon1c4334d90302(TVMArgs args, TVMRetValue* ret) 588     .set_body([](TVMArgs args, TVMRetValue* ret) {
589       IntConstraints problem;
590       PartialSolvedInequalities ret_ineq;
591       if (args.size() == 1) {
592         problem = args[0];
593         ret_ineq = SolveLinearInequalities(problem);
594       } else if (args.size() == 3) {
595         problem = IntConstraints(args[0], args[1], args[2]);
596         ret_ineq = SolveLinearInequalities(problem);
597       } else {
598         LOG(FATAL) << "arith.SolveInequalitiesAsCondition expects 1 or 3 arguments, gets "
599                    << args.size();
600       }
601       *ret = AsConditions(problem->variables, ret_ineq.first, ret_ineq.second);
602     });
603 
__anon1c4334d90402(TVMArgs args, TVMRetValue* ret) 604 TVM_REGISTER_GLOBAL("arith.SolveInequalitiesToRange").set_body([](TVMArgs args, TVMRetValue* ret) {
605   if (args.size() == 1) {
606     *ret = SolveInequalitiesToRange(args[0]);
607   } else if (args.size() == 3) {
608     IntConstraints problem(args[0], args[1], args[2]);
609     *ret = SolveInequalitiesToRange(problem);
610   } else {
611     LOG(FATAL) << "arith.SolveInequalitiesToRange expects 1 or 3 arguments, gets " << args.size();
612   }
613 });
614 
615 TVM_REGISTER_GLOBAL("arith.SolveInequalitiesDeskewRange")
__anon1c4334d90502(TVMArgs args, TVMRetValue* ret) 616     .set_body([](TVMArgs args, TVMRetValue* ret) {
617       if (args.size() == 1) {
618         *ret = SolveInequalitiesDeskewRange(args[0]);
619       } else if (args.size() == 3) {
620         IntConstraints problem(args[0], args[1], args[2]);
621         *ret = SolveInequalitiesDeskewRange(problem);
622       } else {
623         LOG(FATAL) << "arith.SolveInequalitiesDeskewRange expects 1 or 3 arguments, gets "
624                    << args.size();
625       }
626     });
627 
628 }  // namespace arith
629 }  // namespace tvm
630